mostlyai-mock 0.0.1__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/PKG-INFO +41 -22
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/README.md +40 -20
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/mostlyai/mock/__init__.py +1 -1
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/mostlyai/mock/core.py +187 -48
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/pyproject.toml +1 -1
- mostlyai_mock-0.0.1/LICENSE_HEADER +0 -13
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/.gitignore +0 -0
- {mostlyai_mock-0.0.1 → mostlyai_mock-0.0.5}/LICENSE +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mostlyai-mock
|
3
|
-
Version: 0.0.1
|
3
|
+
Version: 0.0.5
|
4
4
|
Summary: Synthetic Mock Data
|
5
5
|
Project-URL: homepage, https://github.com/mostly-ai/mostlyai-mock
|
6
6
|
Project-URL: repository, https://github.com/mostly-ai/mostlyai-mock
|
@@ -8,7 +8,6 @@ Project-URL: documentation, https://mostly-ai.github.io/mostlyai-mock/
|
|
8
8
|
Author-email: MOSTLY AI <dev@mostly.ai>
|
9
9
|
License-Expression: Apache-2.0
|
10
10
|
License-File: LICENSE
|
11
|
-
License-File: LICENSE_HEADER
|
12
11
|
Requires-Python: >=3.10
|
13
12
|
Requires-Dist: litellm>=1.67.0
|
14
13
|
Requires-Dist: numpy>=1.26.3
|
@@ -46,12 +45,13 @@ tables = {
|
|
46
45
|
"columns": {
|
47
46
|
"nationality": {"prompt": "2-letter code for the nationality", "dtype": "string"},
|
48
47
|
"name": {"prompt": "first name and last name of the guest", "dtype": "string"},
|
49
|
-
"gender": {"
|
48
|
+
"gender": {"dtype": "category", "values": ["male", "female"]},
|
50
49
|
"age": {"prompt": "age in years; min: 18, max: 80; avg: 25", "dtype": "integer"},
|
51
50
|
"date_of_birth": {"prompt": "date of birth", "dtype": "date"},
|
52
51
|
"checkin_time": {"prompt": "the check in timestamp of the guest; may 2025", "dtype": "datetime"},
|
53
52
|
"is_vip": {"prompt": "is the guest a VIP", "dtype": "boolean"},
|
54
53
|
"price_per_night": {"prompt": "price paid per night, in EUR", "dtype": "float"},
|
54
|
+
"room_number": {"prompt": "room number", "dtype": "integer", "values": [101, 102, 103, 201, 202, 203, 204]}
|
55
55
|
},
|
56
56
|
}
|
57
57
|
}
|
@@ -65,34 +65,53 @@ print(df)
|
|
65
65
|
from mostlyai import mock
|
66
66
|
|
67
67
|
tables = {
|
68
|
-
"
|
69
|
-
"description": "
|
68
|
+
"customers": {
|
69
|
+
"description": "Customers of a hardware store",
|
70
70
|
"columns": {
|
71
|
-
"
|
72
|
-
"name": {"prompt": "first name and last name of the
|
71
|
+
"customer_id": {"prompt": "the unique id of the customer", "dtype": "integer"},
|
72
|
+
"name": {"prompt": "first name and last name of the customer", "dtype": "string"},
|
73
|
+
},
|
74
|
+
"primary_key": "customer_id",
|
75
|
+
},
|
76
|
+
"orders": {
|
77
|
+
"description": "Orders of a Customer",
|
78
|
+
"columns": {
|
79
|
+
"customer_id": {"prompt": "the customer id for that order", "dtype": "integer"},
|
80
|
+
"order_id": {"prompt": "the unique id of the order", "dtype": "string"},
|
81
|
+
"text": {"prompt": "order text description", "dtype": "string"},
|
82
|
+
"amount": {"prompt": "order amount in USD", "dtype": "float"},
|
73
83
|
},
|
74
|
-
"primary_key": "
|
84
|
+
"primary_key": "order_id",
|
85
|
+
"foreign_keys": [
|
86
|
+
{
|
87
|
+
"column": "customer_id",
|
88
|
+
"referenced_table": "customers",
|
89
|
+
"description": "each customer has anywhere between 1 and 3 orders",
|
90
|
+
}
|
91
|
+
],
|
75
92
|
},
|
76
|
-
"
|
77
|
-
"description": "
|
93
|
+
"items": {
|
94
|
+
"description": "Items in an Order",
|
78
95
|
"columns": {
|
79
|
-
"
|
80
|
-
"
|
81
|
-
"
|
82
|
-
"
|
96
|
+
"item_id": {"prompt": "the unique id of the item", "dtype": "string"},
|
97
|
+
"order_id": {"prompt": "the order id for that item", "dtype": "string"},
|
98
|
+
"name": {"prompt": "the name of the item", "dtype": "string"},
|
99
|
+
"price": {"prompt": "the price of the item in USD", "dtype": "float"},
|
83
100
|
},
|
84
101
|
"foreign_keys": [
|
85
102
|
{
|
86
|
-
"column": "
|
87
|
-
"referenced_table": "
|
88
|
-
"description": "each
|
103
|
+
"column": "order_id",
|
104
|
+
"referenced_table": "orders",
|
105
|
+
"description": "each order has between 2 and 5 items",
|
89
106
|
}
|
90
107
|
],
|
91
108
|
},
|
92
109
|
}
|
93
|
-
data = mock.sample(tables=tables, sample_size=
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
print(
|
110
|
+
data = mock.sample(tables=tables, sample_size=2, model="openai/gpt-4.1")
|
111
|
+
df_customers = data["customers"]
|
112
|
+
df_orders = data["orders"]
|
113
|
+
df_items = data["items"]
|
114
|
+
print(df_customers)
|
115
|
+
print(df_orders)
|
116
|
+
print(df_items)
|
98
117
|
```
|
@@ -27,12 +27,13 @@ tables = {
|
|
27
27
|
"columns": {
|
28
28
|
"nationality": {"prompt": "2-letter code for the nationality", "dtype": "string"},
|
29
29
|
"name": {"prompt": "first name and last name of the guest", "dtype": "string"},
|
30
|
-
"gender": {"
|
30
|
+
"gender": {"dtype": "category", "values": ["male", "female"]},
|
31
31
|
"age": {"prompt": "age in years; min: 18, max: 80; avg: 25", "dtype": "integer"},
|
32
32
|
"date_of_birth": {"prompt": "date of birth", "dtype": "date"},
|
33
33
|
"checkin_time": {"prompt": "the check in timestamp of the guest; may 2025", "dtype": "datetime"},
|
34
34
|
"is_vip": {"prompt": "is the guest a VIP", "dtype": "boolean"},
|
35
35
|
"price_per_night": {"prompt": "price paid per night, in EUR", "dtype": "float"},
|
36
|
+
"room_number": {"prompt": "room number", "dtype": "integer", "values": [101, 102, 103, 201, 202, 203, 204]}
|
36
37
|
},
|
37
38
|
}
|
38
39
|
}
|
@@ -46,34 +47,53 @@ print(df)
|
|
46
47
|
from mostlyai import mock
|
47
48
|
|
48
49
|
tables = {
|
49
|
-
"
|
50
|
-
"description": "
|
50
|
+
"customers": {
|
51
|
+
"description": "Customers of a hardware store",
|
51
52
|
"columns": {
|
52
|
-
"
|
53
|
-
"name": {"prompt": "first name and last name of the
|
53
|
+
"customer_id": {"prompt": "the unique id of the customer", "dtype": "integer"},
|
54
|
+
"name": {"prompt": "first name and last name of the customer", "dtype": "string"},
|
55
|
+
},
|
56
|
+
"primary_key": "customer_id",
|
57
|
+
},
|
58
|
+
"orders": {
|
59
|
+
"description": "Orders of a Customer",
|
60
|
+
"columns": {
|
61
|
+
"customer_id": {"prompt": "the customer id for that order", "dtype": "integer"},
|
62
|
+
"order_id": {"prompt": "the unique id of the order", "dtype": "string"},
|
63
|
+
"text": {"prompt": "order text description", "dtype": "string"},
|
64
|
+
"amount": {"prompt": "order amount in USD", "dtype": "float"},
|
54
65
|
},
|
55
|
-
"primary_key": "
|
66
|
+
"primary_key": "order_id",
|
67
|
+
"foreign_keys": [
|
68
|
+
{
|
69
|
+
"column": "customer_id",
|
70
|
+
"referenced_table": "customers",
|
71
|
+
"description": "each customer has anywhere between 1 and 3 orders",
|
72
|
+
}
|
73
|
+
],
|
56
74
|
},
|
57
|
-
"
|
58
|
-
"description": "
|
75
|
+
"items": {
|
76
|
+
"description": "Items in an Order",
|
59
77
|
"columns": {
|
60
|
-
"
|
61
|
-
"
|
62
|
-
"
|
63
|
-
"
|
78
|
+
"item_id": {"prompt": "the unique id of the item", "dtype": "string"},
|
79
|
+
"order_id": {"prompt": "the order id for that item", "dtype": "string"},
|
80
|
+
"name": {"prompt": "the name of the item", "dtype": "string"},
|
81
|
+
"price": {"prompt": "the price of the item in USD", "dtype": "float"},
|
64
82
|
},
|
65
83
|
"foreign_keys": [
|
66
84
|
{
|
67
|
-
"column": "
|
68
|
-
"referenced_table": "
|
69
|
-
"description": "each
|
85
|
+
"column": "order_id",
|
86
|
+
"referenced_table": "orders",
|
87
|
+
"description": "each order has between 2 and 5 items",
|
70
88
|
}
|
71
89
|
],
|
72
90
|
},
|
73
91
|
}
|
74
|
-
data = mock.sample(tables=tables, sample_size=
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
print(
|
92
|
+
data = mock.sample(tables=tables, sample_size=2, model="openai/gpt-4.1")
|
93
|
+
df_customers = data["customers"]
|
94
|
+
df_orders = data["orders"]
|
95
|
+
df_items = data["items"]
|
96
|
+
print(df_customers)
|
97
|
+
print(df_orders)
|
98
|
+
print(df_items)
|
79
99
|
```
|
@@ -18,10 +18,11 @@ import json
|
|
18
18
|
from collections import deque
|
19
19
|
from collections.abc import Generator
|
20
20
|
from enum import Enum
|
21
|
+
from typing import Any, Literal, Type
|
21
22
|
|
22
23
|
import litellm
|
23
24
|
import pandas as pd
|
24
|
-
from pydantic import BaseModel, Field, RootModel, create_model, field_validator
|
25
|
+
from pydantic import BaseModel, Field, RootModel, create_model, field_validator, model_validator
|
25
26
|
from tqdm import tqdm
|
26
27
|
|
27
28
|
SYSTEM_PROMPT = f"""
|
@@ -88,6 +89,31 @@ class MockConfig(RootModel[dict[str, "TableConfig"]]):
|
|
88
89
|
|
89
90
|
return tables
|
90
91
|
|
92
|
+
@model_validator(mode="after")
|
93
|
+
def validate_no_circular_dependencies(self) -> MockConfig:
|
94
|
+
child_to_parents = {}
|
95
|
+
for table_name, table_config in self.root.items():
|
96
|
+
child_to_parents[table_name] = [fk.referenced_table for fk in table_config.foreign_keys]
|
97
|
+
visited = set()
|
98
|
+
|
99
|
+
def detect_cycle(table_name: str, path: list[str]) -> None:
|
100
|
+
if table_name in path:
|
101
|
+
cycle_start = path.index(table_name)
|
102
|
+
cycle = path[cycle_start:] + [table_name]
|
103
|
+
raise ValueError(f"Circular dependency detected: {' -> '.join(cycle)}")
|
104
|
+
if table_name in visited:
|
105
|
+
return
|
106
|
+
visited.add(table_name)
|
107
|
+
path.append(table_name)
|
108
|
+
for parent in child_to_parents[table_name]:
|
109
|
+
detect_cycle(parent, path)
|
110
|
+
path.pop()
|
111
|
+
|
112
|
+
for table_name in child_to_parents:
|
113
|
+
detect_cycle(table_name, [])
|
114
|
+
|
115
|
+
return self
|
116
|
+
|
91
117
|
|
92
118
|
class TableConfig(BaseModel):
|
93
119
|
description: str = ""
|
@@ -97,14 +123,59 @@ class TableConfig(BaseModel):
|
|
97
123
|
|
98
124
|
|
99
125
|
class ColumnConfig(BaseModel):
|
100
|
-
prompt: str
|
126
|
+
prompt: str = ""
|
101
127
|
dtype: DType
|
128
|
+
values: list[Any] = Field(default_factory=list)
|
129
|
+
|
130
|
+
@model_validator(mode="before")
|
131
|
+
def set_default_dtype(cls, data):
|
132
|
+
if isinstance(data, dict):
|
133
|
+
if "dtype" not in data:
|
134
|
+
if data.get("values"):
|
135
|
+
data["dtype"] = DType.CATEGORY
|
136
|
+
else:
|
137
|
+
data["dtype"] = DType.STRING
|
138
|
+
return data
|
139
|
+
|
140
|
+
@model_validator(mode="after")
|
141
|
+
def ensure_values_are_unique(self) -> ColumnConfig:
|
142
|
+
if self.values:
|
143
|
+
if len(self.values) != len(set(self.values)):
|
144
|
+
raise ValueError("Values must be unique")
|
145
|
+
return self
|
146
|
+
|
147
|
+
@model_validator(mode="after")
|
148
|
+
def ensure_values_are_provided_for_category_dtype(self) -> ColumnConfig:
|
149
|
+
if self.dtype == DType.CATEGORY and not self.values:
|
150
|
+
raise ValueError("At least one value must be provided when dtype is 'category'")
|
151
|
+
return self
|
152
|
+
|
153
|
+
@model_validator(mode="after")
|
154
|
+
def harmonize_values_with_dtypes(self) -> ColumnConfig:
|
155
|
+
if self.values:
|
156
|
+
cast_fn, convertible_to = {
|
157
|
+
DType.INTEGER: (int, "integers"),
|
158
|
+
DType.FLOAT: (float, "floats"),
|
159
|
+
DType.STRING: (str, "strings"),
|
160
|
+
DType.CATEGORY: (lambda c: c, "categories"),
|
161
|
+
DType.BOOLEAN: (bool, "booleans"),
|
162
|
+
DType.DATE: (str, "strings"),
|
163
|
+
DType.DATETIME: (str, "strings"),
|
164
|
+
}[self.dtype]
|
165
|
+
try:
|
166
|
+
self.values = [cast_fn(c) for c in self.values]
|
167
|
+
except ValueError:
|
168
|
+
raise ValueError(
|
169
|
+
f"All values must be convertible to {convertible_to} when dtype is '{self.dtype.value}'"
|
170
|
+
)
|
171
|
+
return self
|
102
172
|
|
103
173
|
|
104
174
|
class DType(str, Enum):
|
105
175
|
INTEGER = "integer"
|
106
176
|
FLOAT = "float"
|
107
177
|
STRING = "string"
|
178
|
+
CATEGORY = "category"
|
108
179
|
BOOLEAN = "boolean"
|
109
180
|
DATE = "date"
|
110
181
|
DATETIME = "datetime"
|
@@ -188,7 +259,7 @@ def _create_table_prompt(
|
|
188
259
|
# add previous rows as context to help the LLM generate consistent data
|
189
260
|
if previous_rows:
|
190
261
|
prompt += f"\n## Previous {len(previous_rows)} Rows:\n\n"
|
191
|
-
prompt += json.dumps(previous_rows, indent=2)
|
262
|
+
prompt += f"{json.dumps(previous_rows, indent=2)}\n\n"
|
192
263
|
|
193
264
|
# add context table name, primary key and data
|
194
265
|
if context_data is not None:
|
@@ -206,12 +277,14 @@ def _create_table_prompt(
|
|
206
277
|
prompt += f"Generate {batch_size} rows for the `{table_name}` table.\n\n"
|
207
278
|
else:
|
208
279
|
prompt += (
|
209
|
-
f"Generate
|
210
|
-
f"The Foreign Key column may only contain values from Context Table Data
|
280
|
+
f"Generate data for the `{table_name}` table. "
|
281
|
+
f"The Foreign Key column may only contain values from Context Table Data. "
|
282
|
+
f"Pay attention to description of the Foreign Key column to understand the relationship.\n\n"
|
211
283
|
)
|
212
284
|
if previous_rows:
|
213
285
|
prompt += (
|
214
286
|
"Generate new rows that maintain consistency with the previous rows where appropriate. "
|
287
|
+
"Don't copy previous rows in the output. "
|
215
288
|
"Don't pay attention to the number of previous rows; there might have been more generated than provided.\n\n"
|
216
289
|
)
|
217
290
|
prompt += f"Do not use code to generate the data.\n\n"
|
@@ -234,19 +307,23 @@ def _create_table_rows_generator(
|
|
234
307
|
llm_config: LLMConfig,
|
235
308
|
) -> Generator[dict]:
|
236
309
|
def create_table_response_format(columns: dict[str, ColumnConfig]) -> BaseModel:
|
237
|
-
|
238
|
-
DType.
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
310
|
+
def create_annotation(column_config: ColumnConfig) -> Type:
|
311
|
+
if column_config.values or column_config.dtype is DType.CATEGORY:
|
312
|
+
return Literal[tuple(column_config.values)]
|
313
|
+
return {
|
314
|
+
DType.INTEGER: int,
|
315
|
+
DType.FLOAT: float,
|
316
|
+
DType.STRING: str,
|
317
|
+
DType.BOOLEAN: bool,
|
318
|
+
# response_format has limited support for JSON Schema features
|
319
|
+
# thus we represent dates and datetimes as strings
|
320
|
+
DType.DATE: str,
|
321
|
+
DType.DATETIME: str,
|
322
|
+
}[column_config.dtype]
|
323
|
+
|
247
324
|
fields = {}
|
248
325
|
for column_name, column_config in columns.items():
|
249
|
-
annotation =
|
326
|
+
annotation = create_annotation(column_config)
|
250
327
|
fields[column_name] = (annotation, Field(...))
|
251
328
|
TableRow = create_model("TableRow", **fields)
|
252
329
|
TableRows = create_model("TableRows", rows=(list[TableRow], ...))
|
@@ -351,16 +428,14 @@ def _convert_table_rows_generator_to_df(
|
|
351
428
|
def align_df_dtypes_with_mock_dtypes(df: pd.DataFrame, columns: dict[str, ColumnConfig]) -> pd.DataFrame:
|
352
429
|
for column_name, column_config in columns.items():
|
353
430
|
if column_config.dtype in [DType.DATE, DType.DATETIME]:
|
354
|
-
# datetime.date, datetime.datetime -> datetime64[ns] / datetime64[ns, tz]
|
355
431
|
df[column_name] = pd.to_datetime(df[column_name], errors="coerce")
|
356
432
|
elif column_config.dtype in [DType.INTEGER, DType.FLOAT]:
|
357
|
-
# int -> int64[pyarrow], float -> double[pyarrow]
|
358
433
|
df[column_name] = pd.to_numeric(df[column_name], errors="coerce", dtype_backend="pyarrow")
|
359
434
|
elif column_config.dtype is DType.BOOLEAN:
|
360
|
-
# bool -> bool
|
361
435
|
df[column_name] = df[column_name].astype(bool)
|
436
|
+
elif column_config.dtype is DType.CATEGORY:
|
437
|
+
df[column_name] = pd.Categorical(df[column_name], categories=column_config.values)
|
362
438
|
else:
|
363
|
-
# other -> string[pyarrow]
|
364
439
|
df[column_name] = df[column_name].astype("string[pyarrow]")
|
365
440
|
return df
|
366
441
|
|
@@ -378,6 +453,44 @@ def _harmonize_sample_size(sample_size: int | dict[str, int], config: MockConfig
|
|
378
453
|
return sample_size
|
379
454
|
|
380
455
|
|
456
|
+
def _build_dependency_graph(config: MockConfig) -> tuple[dict[str, list[str]], dict[str, list[str]], list[str]]:
|
457
|
+
child_to_parents = {}
|
458
|
+
parent_to_children = {}
|
459
|
+
|
460
|
+
for table_name in config.root:
|
461
|
+
child_to_parents[table_name] = []
|
462
|
+
parent_to_children[table_name] = []
|
463
|
+
|
464
|
+
for table_name, table_config in config.root.items():
|
465
|
+
if table_config.foreign_keys:
|
466
|
+
for fk in table_config.foreign_keys:
|
467
|
+
referenced_table = fk.referenced_table
|
468
|
+
child_to_parents[table_name].append(referenced_table)
|
469
|
+
parent_to_children[referenced_table].append(table_name)
|
470
|
+
|
471
|
+
subject_tables = [table_name for table_name, deps in child_to_parents.items() if not deps]
|
472
|
+
return child_to_parents, parent_to_children, subject_tables
|
473
|
+
|
474
|
+
|
475
|
+
def _build_execution_plan(parent_to_children: dict[str, list[str]], subject_tables: list[str]) -> list[str]:
|
476
|
+
execution_plan = []
|
477
|
+
bfs_queue = list(subject_tables)
|
478
|
+
processed = set()
|
479
|
+
|
480
|
+
while bfs_queue:
|
481
|
+
table_name = bfs_queue.pop(0)
|
482
|
+
if table_name in processed:
|
483
|
+
continue
|
484
|
+
|
485
|
+
execution_plan.append(table_name)
|
486
|
+
processed.add(table_name)
|
487
|
+
|
488
|
+
for child in parent_to_children[table_name]:
|
489
|
+
if child not in bfs_queue and child not in processed:
|
490
|
+
bfs_queue.append(child)
|
491
|
+
return execution_plan
|
492
|
+
|
493
|
+
|
381
494
|
def sample(
|
382
495
|
*,
|
383
496
|
tables: dict[str, dict],
|
@@ -404,6 +517,8 @@ def sample(
|
|
404
517
|
- `openai/gpt-4.1`
|
405
518
|
- `gemini/gemini-2.0-flash`
|
406
519
|
- `gemini/gemini-2.5-flash-preview-04-17`
|
520
|
+
- `groq/llama-3.3-70b-versatile`
|
521
|
+
- `anthropic/claude-3-7-sonnet-latest`
|
407
522
|
See https://docs.litellm.ai/docs/providers/ for more options.
|
408
523
|
api_key (str | None): The API key to use for the LLM. If not provided, LiteLLM will take it from the environment variables.
|
409
524
|
temperature (float): The temperature to use for the LLM. Default is 1.0.
|
@@ -423,12 +538,13 @@ def sample(
|
|
423
538
|
"columns": {
|
424
539
|
"nationality": {"prompt": "2-letter code for the nationality", "dtype": "string"},
|
425
540
|
"name": {"prompt": "first name and last name of the guest", "dtype": "string"},
|
426
|
-
"gender": {"
|
541
|
+
"gender": {"dtype": "category", "values": ["male", "female"]},
|
427
542
|
"age": {"prompt": "age in years; min: 18, max: 80; avg: 25", "dtype": "integer"},
|
428
543
|
"date_of_birth": {"prompt": "date of birth", "dtype": "date"},
|
429
544
|
"checkin_time": {"prompt": "the check in timestamp of the guest; may 2025", "dtype": "datetime"},
|
430
545
|
"is_vip": {"prompt": "is the guest a VIP", "dtype": "boolean"},
|
431
546
|
"price_per_night": {"prompt": "price paid per night, in EUR", "dtype": "float"},
|
547
|
+
"room_number": {"prompt": "room number", "dtype": "integer", "values": [101, 102, 103, 201, 202, 203, 204]}
|
432
548
|
},
|
433
549
|
}
|
434
550
|
}
|
@@ -440,34 +556,52 @@ def sample(
|
|
440
556
|
from mostlyai import mock
|
441
557
|
|
442
558
|
tables = {
|
443
|
-
"
|
444
|
-
"description": "
|
559
|
+
"customers": {
|
560
|
+
"description": "Customers of a hardware store",
|
445
561
|
"columns": {
|
446
|
-
"
|
447
|
-
"name": {"prompt": "first name and last name of the
|
562
|
+
"customer_id": {"prompt": "the unique id of the customer", "dtype": "integer"},
|
563
|
+
"name": {"prompt": "first name and last name of the customer", "dtype": "string"},
|
448
564
|
},
|
449
|
-
"primary_key": "
|
565
|
+
"primary_key": "customer_id",
|
450
566
|
},
|
451
|
-
"
|
452
|
-
"description": "
|
567
|
+
"orders": {
|
568
|
+
"description": "Orders of a Customer",
|
453
569
|
"columns": {
|
454
|
-
"
|
455
|
-
"
|
456
|
-
"text": {"prompt": "
|
457
|
-
"amount": {"prompt": "
|
570
|
+
"customer_id": {"prompt": "the customer id for that order", "dtype": "integer"},
|
571
|
+
"order_id": {"prompt": "the unique id of the order", "dtype": "string"},
|
572
|
+
"text": {"prompt": "order text description", "dtype": "string"},
|
573
|
+
"amount": {"prompt": "order amount in USD", "dtype": "float"},
|
458
574
|
},
|
575
|
+
"primary_key": "order_id",
|
459
576
|
"foreign_keys": [
|
460
577
|
{
|
461
|
-
"column": "
|
462
|
-
"referenced_table": "
|
463
|
-
"description": "each
|
578
|
+
"column": "customer_id",
|
579
|
+
"referenced_table": "customers",
|
580
|
+
"description": "each customer has anywhere between 1 and 3 orders",
|
581
|
+
}
|
582
|
+
],
|
583
|
+
},
|
584
|
+
"items": {
|
585
|
+
"description": "Items in an Order",
|
586
|
+
"columns": {
|
587
|
+
"item_id": {"prompt": "the unique id of the item", "dtype": "string"},
|
588
|
+
"order_id": {"prompt": "the order id for that item", "dtype": "string"},
|
589
|
+
"name": {"prompt": "the name of the item", "dtype": "string"},
|
590
|
+
"price": {"prompt": "the price of the item in USD", "dtype": "float"},
|
591
|
+
},
|
592
|
+
"foreign_keys": [
|
593
|
+
{
|
594
|
+
"column": "order_id",
|
595
|
+
"referenced_table": "orders",
|
596
|
+
"description": "each order has between 2 and 5 items",
|
464
597
|
}
|
465
598
|
],
|
466
599
|
},
|
467
600
|
}
|
468
|
-
data = mock.sample(tables=tables, sample_size=
|
469
|
-
|
470
|
-
|
601
|
+
data = mock.sample(tables=tables, sample_size=2, model="openai/gpt-4.1")
|
602
|
+
df_customers = data["customers"]
|
603
|
+
df_orders = data["orders"]
|
604
|
+
df_items = data["items"]
|
471
605
|
```
|
472
606
|
"""
|
473
607
|
|
@@ -475,9 +609,15 @@ def sample(
|
|
475
609
|
|
476
610
|
sample_size = _harmonize_sample_size(sample_size, config)
|
477
611
|
primary_keys = {table_name: table_config.primary_key for table_name, table_config in config.root.items()}
|
478
|
-
|
479
|
-
|
480
|
-
|
612
|
+
|
613
|
+
child_to_parents, parent_to_children, subject_tables = _build_dependency_graph(config)
|
614
|
+
execution_plan: list[str] = _build_execution_plan(parent_to_children, subject_tables)
|
615
|
+
|
616
|
+
results: dict[str, pd.DataFrame] = {}
|
617
|
+
|
618
|
+
for table_name in execution_plan:
|
619
|
+
table_config = config.root[table_name]
|
620
|
+
if not child_to_parents[table_name]:
|
481
621
|
# subject table
|
482
622
|
df = _sample_table(
|
483
623
|
table_name=table_name,
|
@@ -491,22 +631,21 @@ def sample(
|
|
491
631
|
previous_rows_size=5,
|
492
632
|
llm_config=LLMConfig(model=model, api_key=api_key),
|
493
633
|
)
|
494
|
-
|
495
|
-
#
|
634
|
+
else:
|
635
|
+
# sequencial table
|
636
|
+
referenced_table = table_config.foreign_keys[0].referenced_table
|
496
637
|
df = _sample_table(
|
497
638
|
table_name=table_name,
|
498
639
|
table_config=table_config,
|
499
640
|
primary_keys=primary_keys,
|
500
641
|
sample_size=None,
|
501
|
-
context_data=
|
642
|
+
context_data=results[referenced_table],
|
502
643
|
temperature=temperature,
|
503
644
|
top_p=top_p,
|
504
645
|
batch_size=1, # generate one sequence at a time
|
505
646
|
previous_rows_size=5,
|
506
647
|
llm_config=LLMConfig(model=model, api_key=api_key),
|
507
648
|
)
|
508
|
-
|
509
|
-
raise RuntimeError("Only 1 or 2 table setups are supported for now")
|
510
|
-
dfs[table_name] = df
|
649
|
+
results[table_name] = df
|
511
650
|
|
512
|
-
return
|
651
|
+
return results if len(results) > 1 else next(iter(results.values()))
|
@@ -1,13 +0,0 @@
|
|
1
|
-
Copyright 2025 MOSTLY AI
|
2
|
-
|
3
|
-
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
you may not use this file except in compliance with the License.
|
5
|
-
You may obtain a copy of the License at
|
6
|
-
|
7
|
-
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
|
9
|
-
Unless required by applicable law or agreed to in writing, software
|
10
|
-
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
See the License for the specific language governing permissions and
|
13
|
-
limitations under the License.
|
File without changes
|
File without changes
|