mostlyai-mock 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mostlyai/mock/__init__.py CHANGED
@@ -15,4 +15,4 @@
  from mostlyai.mock.core import sample

  __all__ = ["sample"]
- __version__ = "0.1.7"  # Do not set this manually. Use poetry version [params].
+ __version__ = "0.1.9"  # Do not set this manually. Use poetry version [params].
mostlyai/mock/core.py CHANGED
@@ -14,42 +14,29 @@

  from __future__ import annotations

- import itertools
+ import asyncio
+ import concurrent.futures
  import json
+ import math
  from collections import deque
- from collections.abc import Generator
+ from collections.abc import AsyncGenerator
  from enum import Enum
+ from io import StringIO
  from typing import Any, Literal

+ import dateutil.parser
  import litellm
  import pandas as pd
  import tenacity
  from pydantic import BaseModel, Field, RootModel, create_model, field_validator, model_validator
- from tqdm import tqdm
+ from tqdm.asyncio import tqdm

  litellm.suppress_debug_info = True

- SYSTEM_PROMPT = """
- You are a specialized mock data generator designed to create highly realistic, contextually appropriate data based on schema definitions.

- Your task is to:
-
- 1. Generate data that strictly adheres to the provided schema constraints (data types, ranges, formats)
- 2. Ensure logical consistency across related tables and foreign key relationships
- 3. Create contextually appropriate values that reflect real-world patterns and distributions
- 4. Produce diverse, non-repetitive data that avoids obvious patterns
- 5. Respect uniqueness constraints and other data integrity rules
- 6. When enriching existing data, ensure that new values are consistent with existing values
- 7. Return well-formatted JSON output that can be directly parsed
- 8. Don't use markdown formatting
-
- For numeric fields, generate realistic distributions rather than random values. For text fields, create contextually \
- appropriate content. For dates and timestamps, ensure logical chronology. Always maintain referential integrity \
- across tables.
-
- When enriching existing data, carefully analyze the patterns and relationships in the existing columns \
- to generate compatible and realistic values for the missing columns.
- """
+ class LLMOutputFormat(str, Enum):
+     JSON = "JSON"
+     CSV = "CSV"


  class LLMConfig(BaseModel):
@@ -162,6 +149,12 @@ class ColumnConfig(BaseModel):
              raise ValueError("At least one value must be provided when dtype is 'category'")
          return self

+     @model_validator(mode="after")
+     def override_values_for_boolean_dtype(self) -> ColumnConfig:
+         if self.dtype == DType.BOOLEAN:
+             self.values = [True, False]
+         return self
+
      @model_validator(mode="after")
      def harmonize_values_with_dtypes(self) -> ColumnConfig:
          if self.values:
@@ -199,18 +192,18 @@ class ForeignKeyConfig(BaseModel):
      prompt: str | None = None


- def _sample_table(
+ async def _sample_table(
      *,
      name: str,
      prompt: str,
      columns: dict[str, ColumnConfig],
-     foreign_keys: list[ForeignKeyConfig] | None,
-     primary_keys: dict[str, str] | None,
+     foreign_keys: list[ForeignKeyConfig],
+     primary_keys: dict[str, str],
      data: dict[str, pd.DataFrame],
      sample_size: int,
-     batch_size: int,
      previous_rows_size: int,
      non_context_size: int | None,
+     n_workers: int,
      llm_config: LLMConfig,
  ) -> pd.DataFrame:
      table_rows_generator = _create_table_rows_generator(
@@ -221,28 +214,62 @@ def _sample_table(
          foreign_keys=foreign_keys,
          data=data,
          sample_size=sample_size,
-         batch_size=batch_size,
          previous_rows_size=previous_rows_size,
          non_context_size=non_context_size,
+         n_workers=n_workers,
          llm_config=llm_config,
      )
      table_rows_generator = tqdm(table_rows_generator, desc=f"Generating rows for table `{name}`".ljust(45))
-     table_df = _convert_table_rows_generator_to_df(table_rows_generator=table_rows_generator, columns=columns)
+     table_df = await _convert_table_rows_generator_to_df(table_rows_generator=table_rows_generator, columns=columns)
      return table_df


+ def _sample_table_sync(*args, **kwargs) -> pd.DataFrame:
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     try:
+         return loop.run_until_complete(_sample_table(*args, **kwargs))
+     finally:
+         loop.close()
+
+
+ def _create_system_prompt(llm_output_format: LLMOutputFormat) -> str:
+     return f"""
+ You are a specialized data generator designed to create highly realistic, contextually appropriate data based on schema definitions.
+
+ Your task is to:
+
+ 1. Generate data that strictly adheres to the provided schema constraints (data types, ranges, formats)
+ 2. Ensure logical consistency across related tables and foreign key relationships
+ 3. Create contextually appropriate values that reflect real-world patterns and distributions
+ 4. Produce diverse, non-repetitive data that avoids obvious patterns
+ 5. Respect uniqueness constraints and other data integrity rules
+ 6. When enriching existing data, ensure that new values are consistent with existing values
+ 7. Return well-formatted {llm_output_format.value} output that can be directly parsed
+ 8. Don't use markdown formatting
+
+ For numeric fields, generate realistic distributions rather than random values. For text fields, create contextually \
+ appropriate content. For dates and timestamps, ensure logical chronology. Always maintain referential integrity \
+ across tables.
+
+ When enriching existing data, carefully analyze the patterns and relationships in the existing columns \
+ to generate compatible and realistic values for the missing columns.
+ """
+
+
  def _create_table_prompt(
      *,
      name: str,
      prompt: str,
      columns: dict[str, ColumnConfig],
-     primary_keys: dict[str, str] | None,
+     primary_keys: dict[str, str],
      batch_size: int | None,
-     foreign_keys: list[ForeignKeyConfig] | None,
+     foreign_keys: list[ForeignKeyConfig],
      existing_data: pd.DataFrame | None,
      context_data: pd.DataFrame | None,
      non_context_data: dict[str, pd.DataFrame] | None,
      previous_rows: list[dict] | None,
+     llm_output_format: LLMOutputFormat,
  ) -> str:
      # add table prompt
      prompt = f"# {prompt}\n\n"
@@ -345,7 +372,7 @@ def _create_table_prompt(

      prompt += f"{verb.capitalize()} data for the Target Table `{name}`.\n\n"
      if n_rows is not None:
-         prompt += f"Number of rows to {verb}: `{n_rows}`.\n\n"
+         prompt += f"Number of data rows to {verb}: `{n_rows}`.\n\n"

      if has_context_table_section:
          assert foreign_keys
@@ -387,131 +414,341 @@ def _create_table_prompt(

      prompt += f"Do not use code to {verb} the data.\n\n"

-     prompt += "Return data as a JSON string."
-     prompt += " The JSON string should have 'rows' key at the top level. The value of 'rows' key should be a list of JSON objects."
-     prompt += " Each JSON object should have column names as keys and values as column values."
+     prompt += f"Return data as a {llm_output_format.value} string."
+     if llm_output_format == LLMOutputFormat.JSON:
+         prompt += " The JSON string should have 'rows' key at the top level."
+         prompt += " The value of 'rows' key should be a list of JSON objects."
+         prompt += " Each JSON object should have column names as keys and values as column values."
+     else:  # llm_output_format == LLMOutputFormat.CSV
+         prompt += " The CSV string should have a header row with column names."
+         prompt += " The CSV string should have a data row for each row to be generated."
+         prompt += " The CSV string should have a newline character at the end of each row."
+         prompt += " Each value in the CSV string should be enclosed in double quotes."
+
      if existing_data is not None:
-         prompt += (
-             f" Only include the following columns in the JSON string: {list(columns.keys() - existing_data.columns)}."
-         )
+         prompt += f" Only include the following columns in the {llm_output_format.value} string: {list(columns.keys() - existing_data.columns)}."
+
+     if llm_output_format == LLMOutputFormat.CSV and batch_size > 10:
+         prompt += " Additionally, add column called `_ROW_IDX` that is a counter from 1 to the number of rows generated so far within current batch."
+
      prompt += "\n"
      return prompt


- def _create_table_rows_generator(
+ def _completion_with_retries(*args, **kwargs):
+     n_attempts = 3
+
+     def print_on_retry(_):
+         print(" * Calling LLM again... * ", end="", flush=True)
+
+     # try up to 3 times, print a message to the user on each retry
+     retryer = tenacity.AsyncRetrying(
+         stop=tenacity.stop_after_attempt(n_attempts), reraise=True, before_sleep=print_on_retry
+     )
+     return retryer(litellm.acompletion, *args, **kwargs)
+
+
+ async def _yield_rows_from_json_chunks_stream(response: litellm.CustomStreamWrapper) -> AsyncGenerator[dict]:
+     def buffer_to_row(buffer: list[str]) -> dict:
+         return json.loads("".join(buffer))
+
+     # starting with dirty buffer is to handle the `{"rows": []}` case
+     buffer = list("garbage")
+     rows_json_started = False
+     in_row_json = False
+     async for chunk in response:
+         delta = chunk.choices[0].delta.content
+         if delta is None:
+             continue
+         for char in delta:
+             buffer.append(char)
+             if char == "{" and not rows_json_started:
+                 # {"rows": [{"name": "Jo\}h\{n"}]}
+                 # * <- start of rows json stream
+                 rows_json_started = True
+             elif char == "{" and not in_row_json:
+                 # {"rows": [{"name": "Jo\}h\{n"}]}
+                 #           * <- start of single row json stream
+                 buffer = list("{")
+                 in_row_json = True
+             elif char == "}":
+                 # {"rows": [{"name": "Jo\}h\{n"}]}
+                 #                        *      * * <- any of these
+                 try:
+                     row = buffer_to_row(buffer)
+                 except Exception:
+                     # in case of any error, silently drop the row
+                     continue
+                 finally:
+                     buffer = list()
+                     in_row_json = False
+                 yield row
+
+
+ async def _yield_rows_from_csv_chunks_stream(response: litellm.CustomStreamWrapper) -> AsyncGenerator[dict]:
+     def buffer_to_row(buffer: list[str]) -> list[str]:
+         return pd.read_csv(StringIO("".join(buffer)), header=None).astype(str).iloc[0].to_list()
+
+     buffer = list()
+     header = None
+     async for chunk in response:
+         delta = chunk.choices[0].delta.content
+         if delta is None:
+             continue
+         for char in delta:
+             buffer.append(char)
+             if char == "\n":
+                 try:
+                     row = buffer_to_row(buffer)
+                 except Exception:
+                     # in case of any error, silently drop the row
+                     continue
+                 finally:
+                     buffer = list()
+                 if header is None:
+                     # column1,column2,column3\n
+                     #                        ** <- end of header row
+                     header = row
+                 else:
+                     # value_1,value_2,value_3\n
+                     #                        ** <- end of data row
+                     yield dict(zip(header, row))
+     if buffer:
+         # last row might not finish with a newline, in which case the buffer would not be empty here
+         try:
+             last_row = buffer_to_row(buffer)
+             yield dict(zip(header, last_row))
+         except Exception:
+             # in case of any error, silently drop the row
+             pass
+
+
+ def _create_structured_output_schema(
+     columns: dict[str, ColumnConfig], existing_data: pd.DataFrame | None
+ ) -> type[BaseModel]:
+     def create_annotation(column_config: ColumnConfig) -> type:
+         if column_config.values or column_config.dtype is DType.CATEGORY:
+             return Literal[tuple(column_config.values)]
+         return {
+             DType.INTEGER: int | None,
+             DType.FLOAT: float | None,
+             DType.STRING: str | None,
+             DType.BOOLEAN: bool | None,
+             # response_format has limited support for JSON Schema features
+             # thus we represent dates and datetimes as strings
+             DType.DATE: str | None,
+             DType.DATETIME: str | None,
+         }[column_config.dtype]
+
+     fields = {}
+     for column_name, column_config in columns.items():
+         if existing_data is not None and column_name in existing_data.columns:
+             continue  # skip columns that already exist in existing data
+         annotation = create_annotation(column_config)
+         fields[column_name] = (annotation, Field(...))
+     TableRow = create_model("TableRow", **fields)
+     TableRows = create_model("TableRows", rows=(list[TableRow], ...))
+     return TableRows
+
+
+ async def _worker(
+     *,
+     name: str,
+     prompt: str,
+     columns: dict[str, ColumnConfig],
+     foreign_keys: list[ForeignKeyConfig],
+     primary_keys: dict[str, str],
+     previous_rows: deque[dict],
+     batch_queue: asyncio.Queue,
+     result_queue: asyncio.Queue,
+     retry_queue: asyncio.Queue,
+     n_workers: int,
+     llm_output_format: LLMOutputFormat,
+     llm_config: LLMConfig,
+ ):
+     try:
+         while True:
+             do_repeat_task = False
+
+             # get task from the batch_queue
+             batch_idx, task = await batch_queue.get()
+             if task is None:
+                 # no more tasks for the worker; break the loop
+                 batch_queue.task_done()
+                 break
+
+             # deconstruct task
+             batch_size = task["batch_size"]
+             existing_batch = task.get("existing_batch")
+             context_batch = task.get("context_batch")
+             non_context_batch = task.get("non_context_batch")
+
+             # resolve columns to generate
+             generated_columns = set(columns.keys())
+             if existing_batch is not None:
+                 generated_columns = generated_columns - set(existing_batch.columns)
+
+             # construct schema for Structured Outputs (applies to JSON LLMOutputFormat only)
+             structured_output_schema = None
+             if llm_output_format == LLMOutputFormat.JSON:
+                 structured_output_schema = _create_structured_output_schema(
+                     columns=columns, existing_data=existing_batch
+                 )
+
+             # construct litellm kwargs
+             litellm_kwargs = {
+                 "temperature": llm_config.temperature,
+                 "top_p": llm_config.top_p,
+                 "model": llm_config.model,
+                 "api_key": llm_config.api_key,
+                 "stream": True,
+             }
+
+             # construct messages
+             system_prompt = _create_system_prompt(llm_output_format)
+             user_prompt = _create_table_prompt(
+                 name=name,
+                 prompt=prompt,
+                 columns=columns,
+                 primary_keys=primary_keys,
+                 batch_size=batch_size,
+                 foreign_keys=foreign_keys,
+                 existing_data=existing_batch,
+                 context_data=context_batch,
+                 non_context_data=non_context_batch,
+                 previous_rows=list(previous_rows),
+                 llm_output_format=llm_output_format,
+             )
+             messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+
+             if generated_columns:
+                 # make LLM call
+                 response = await _completion_with_retries(
+                     messages=messages, response_format=structured_output_schema, **litellm_kwargs
+                 )
+                 yield_rows_from_chunks_stream = {
+                     LLMOutputFormat.JSON: _yield_rows_from_json_chunks_stream,
+                     LLMOutputFormat.CSV: _yield_rows_from_csv_chunks_stream,
+                 }[llm_output_format]
+                 rows_stream = yield_rows_from_chunks_stream(response)
+             else:
+                 # skip roundtrip to LLM in case all columns are provided in existing data
+                 assert existing_batch is not None
+
+                 async def _yield_empty_rows(n_rows: int) -> AsyncGenerator[dict]:
+                     for _ in range(n_rows):
+                         yield {}
+
+                 rows_stream = _yield_empty_rows(len(existing_batch))
+
+             # we first generate all rows in the batch, in order to run consistency checks
+             rows_generated_part = []
+             async for row_generated_part in rows_stream:
+                 # remove internal columns, if exist
+                 row_generated_part = {k: v for k, v in row_generated_part.items() if k in generated_columns}
+
+                 if set(row_generated_part.keys()) != generated_columns:
+                     if context_batch is not None or existing_batch is not None:
+                         # in case of linked tables and data enrichment, it's critical that all rows have expected columns
+                         print(" * Malformed row, repeating batch... * ", end="", flush=True)
+                         do_repeat_task = True
+                         break
+                     else:
+                         # in case of flat tables generation, each row is independent, therefore we only skip the invalid row
+                         continue
+                 rows_generated_part.append(row_generated_part)
+
+             # at least some valid rows are expected per batch, repeat the batch otherwise
+             if len(rows_generated_part) == 0:
+                 print(" * No valid rows were generated, repeating batch... * ", end="", flush=True)
+                 do_repeat_task = True
+
+             # in case of data enrichment, check that all rows were completed successfully
+             if existing_batch is not None and len(rows_generated_part) != len(existing_batch):
+                 print(" * Some rows were not enriched successfully, repeating batch... * ", end="", flush=True)
+                 do_repeat_task = True
+
+             if do_repeat_task:
+                 # allow 10 retries across all workers before propagating the exception to the orchestrator
+                 await retry_queue.put(1)
+                 if retry_queue.qsize() < 10:
+                     # put task back to the front of the batch queue
+                     await batch_queue.put((batch_idx, task))
+                 else:
+                     # inform the orchestrator that max retries were reached
+                     raise RuntimeError(
+                         "Too many malformed batches were generated. "
+                         "Consider changing the model in order to make generation more stable."
+                     )
+
+                 # mark current task as done
+                 batch_queue.task_done()
+                 continue
+
+             # collapse existing and generated parts into coherent rows
+             rows = []
+             for row_idx, row_generated_part in enumerate(rows_generated_part):
+                 row_existing_part = existing_batch.iloc[row_idx].to_dict() if existing_batch is not None else {}
+                 row = {**row_generated_part, **row_existing_part}
+                 # keep columns order according to user's spec
+                 row = {column: row[column] for column in columns.keys()}
+                 rows.append(row)
+
+             # track previous rows for improved data consistency, in case of sequential generation
+             if n_workers == 1:
+                 previous_rows.extend(rows)
+
+             # put rows to the result queue and mark current task as done
+             await result_queue.put((batch_idx, rows))
+             batch_queue.task_done()
+     except Exception as e:
+         # propagate any exception through the result queue
+         await result_queue.put((batch_idx, e))
+         raise
+
+
+ async def _create_table_rows_generator(
      *,
      name: str,
      prompt: str,
      columns: dict[str, ColumnConfig],
-     foreign_keys: list[ForeignKeyConfig] | None,
-     primary_keys: dict[str, str] | None,
+     foreign_keys: list[ForeignKeyConfig],
+     primary_keys: dict[str, str],
      data: dict[str, pd.DataFrame],
      sample_size: int,
-     batch_size: int,
      previous_rows_size: int,
      non_context_size: int | None,
+     n_workers: int,
      llm_config: LLMConfig,
- ) -> Generator[dict]:
-     def create_table_response_format(
-         columns: dict[str, ColumnConfig], existing_data: pd.DataFrame | None
-     ) -> tuple[type[BaseModel], int]:
-         def create_annotation(column_config: ColumnConfig) -> type:
-             if column_config.values or column_config.dtype is DType.CATEGORY:
-                 return Literal[tuple(column_config.values)]
-             return {
-                 DType.INTEGER: int | None,
-                 DType.FLOAT: float | None,
-                 DType.STRING: str | None,
-                 DType.BOOLEAN: bool | None,
-                 # response_format has limited support for JSON Schema features
-                 # thus we represent dates and datetimes as strings
-                 DType.DATE: str | None,
-                 DType.DATETIME: str | None,
-             }[column_config.dtype]
-
-         fields = {}
-         for column_name, column_config in columns.items():
-             if existing_data is not None and column_name in existing_data.columns:
-                 continue  # skip columns that already exist in existing data
-             annotation = create_annotation(column_config)
-             fields[column_name] = (annotation, Field(...))
-         TableRow = create_model("TableRow", **fields)
-         TableRows = create_model("TableRows", rows=(list[TableRow], ...))
-         n_enforced_columns = len(fields)
-         return TableRows, n_enforced_columns
-
-     def yield_rows_from_json_chunks_stream(response: litellm.CustomStreamWrapper) -> Generator[dict]:
-         # starting with dirty buffer is to handle the `{"rows": []}` case
-         buffer = "garbage"
-         rows_json_started = False
-         in_row_json = False
-         for chunk in response:
-             delta = chunk.choices[0].delta.content
-             if delta is None:
-                 continue
-             for char in delta:
-                 buffer += char
-                 if char == "{" and not rows_json_started:
-                     # {"rows": [{"name": "Jo\}h\{n"}]}
-                     # * <- start of rows json stream
-                     rows_json_started = True
-                 elif char == "{" and not in_row_json:
-                     # {"rows": [{"name": "Jo\}h\{n"}]}
-                     #           * <- start of single row json stream
-                     buffer = "{"
-                     in_row_json = True
-                 elif char == "}":
-                     # {"rows": [{"name": "Jo\}h\{n"}]}
-                     #                        *      * * <- any of these
-                     try:
-                         row = json.loads(buffer)
-                         yield row
-                         buffer = ""
-                         in_row_json = False
-                     except json.JSONDecodeError:
-                         continue
-
-     def batch_infinitely(data: pd.DataFrame | None) -> Generator[pd.DataFrame | None]:
-         while True:
-             if data is None:
-                 yield None
-             else:
-                 for i in range(0, len(data), batch_size):
-                     yield data.iloc[i : i + batch_size]
+ ) -> AsyncGenerator[dict]:
+     batch_size = 20  # generate 20 root table rows at a time

-     def completion_with_retries(*args, **kwargs):
-         n_attempts = 3
+     def supports_structured_outputs(model: str) -> bool:
+         model = model.removeprefix("litellm_proxy/")
+         supported_params = litellm.get_supported_openai_params(model=model) or []
+         return "response_format" in supported_params and litellm.supports_response_schema(model)

-         def print_on_retry(_):
-             print(" * Trying again... * ", end="", flush=True)
+     llm_output_format = LLMOutputFormat.JSON if supports_structured_outputs(llm_config.model) else LLMOutputFormat.CSV

-         # try up to 3 times, print a message to the user on each retry
-         retryer = tenacity.Retrying(
-             stop=tenacity.stop_after_attempt(n_attempts), reraise=True, before_sleep=print_on_retry
-         )
-         return retryer(litellm.completion, *args, **kwargs)
-
-     if not llm_config.model.startswith("litellm_proxy/"):
-         # ensure model supports response_format and json schema (this check does not work with litellm_proxy)
-         supported_params = litellm.get_supported_openai_params(model=llm_config.model) or []
-         assert "response_format" in supported_params and litellm.supports_response_schema(llm_config.model), (
-             "The model does not support structured output / JSON mode."
-         )
+     previous_rows = deque(maxlen=previous_rows_size)

      # derive data for augmentation
      existing_data: pd.DataFrame | None = None
      if name in data:
          existing_data = data[name]
          sample_size = len(existing_data)
+         batch_size = 10  # augment 10 root table rows at a time

      # derive context data (if first foreign key is present) and harmonize sample size accordingly
      context_data: pd.DataFrame | None = None
+     context_batches: list[pd.DataFrame] | None = None
      if foreign_keys and foreign_keys[0].referenced_table != name:  # self-dependency is not considered as context
          context_table_name = foreign_keys[0].referenced_table
          assert context_table_name in data
          context_data = data[context_table_name]
-         batch_size = 1  # generate one sequence at a time
+         batch_size = 1  # generate 1 sequence at a time
          sample_size = len(context_data)
+         context_batches = [data.iloc[i : i + batch_size] for i in range(0, len(data), batch_size)]

      # derive non-context data (if more than one foreign key is present)
      non_context_data: dict[str, pd.DataFrame] = {}
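
Note: `_yield_rows_from_json_chunks_stream` above emits each row as soon as its closing brace arrives, instead of waiting for the full completion. A toy, self-contained version of the same character-scanning idea (simplified; like the original, it assumes braces inside string values are escaped):

```python
import json

def iter_rows(stream_text: str):
    # dirty initial buffer handles the `{"rows": []}` case, as above
    buffer, rows_started, in_row = list("garbage"), False, False
    for char in stream_text:  # the real code consumes litellm stream chunks
        buffer.append(char)
        if char == "{" and not rows_started:
            rows_started = True                # outer {"rows": ...} object
        elif char == "{" and not in_row:
            buffer, in_row = list("{"), True   # start of one row object
        elif char == "}":
            try:
                row = json.loads("".join(buffer))  # complete row object?
            except Exception:
                continue                       # not parseable; drop and move on
            finally:
                buffer, in_row = [], False
            yield row

print(list(iter_rows('{"rows": [{"a": 1}, {"a": 2}]}')))  # [{'a': 1}, {'a': 2}]
```
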
@@ -524,18 +761,23 @@ def _create_table_rows_generator(
          assert non_context_table_name in data
          non_context_data[non_context_table_name] = data[non_context_table_name]

-     litellm_kwargs = {
-         "temperature": llm_config.temperature,
-         "top_p": llm_config.top_p,
-         "model": llm_config.model,
-         "api_key": llm_config.api_key,
-         "stream": True,
-     }
+     # calculate batch_sizes
+     n_total_batches = len(context_batches) if context_batches is not None else math.ceil(sample_size / batch_size)
+     batch_sizes = [batch_size] * n_total_batches
+     if context_batches is None:
+         # optimise the last batch size for flat tables
+         # +2 because LLM may not always count the rows correctly
+         batch_sizes[-1] = sample_size - sum(batch_sizes[:-1]) + 2
+
+     # initialize queues for async communication
+     batch_queue = asyncio.PriorityQueue()
+     result_queue = asyncio.Queue()
+     retry_queue = asyncio.Queue()
+
+     # populate batch queue
+     for batch_idx in range(n_total_batches):
+         context_batch = context_batches[batch_idx] if context_batches is not None else None

-     batch_idx = 0
-     yielded_sequences = 0
-     previous_rows = deque(maxlen=previous_rows_size)
-     for context_batch in batch_infinitely(context_data):
          # pick existing rows for current batch
          existing_batch: pd.DataFrame | None = None
          if existing_data is not None:
@@ -559,71 +801,94 @@ def _create_table_rows_generator(
              table_name: df.sample(frac=1.0).head(non_context_size) for table_name, df in non_context_data.items()
          }

-         if context_batch is None:
-             # for root tables, scale down batch size in order to prevent excessive generations
-             remaining_rows = sample_size - yielded_sequences
-             if batch_size >= remaining_rows:
-                 batch_size = remaining_rows + 2  # +2 because LLM may not always count the rows correctly
-
-         response_format, n_enforced_columns = create_table_response_format(
-             columns=columns, existing_data=existing_batch
-         )
-
-         llm_prompt = _create_table_prompt(
-             name=name,
-             prompt=prompt,
-             columns=columns,
-             primary_keys=primary_keys,
-             batch_size=batch_size,
-             foreign_keys=foreign_keys,
-             existing_data=existing_batch,
-             context_data=context_batch,
-             non_context_data=non_context_batch,
-             previous_rows=list(previous_rows),
+         task = {
+             "batch_size": batch_sizes[batch_idx],
+             "existing_batch": existing_batch,
+             "context_batch": context_batch,
+             "non_context_batch": non_context_batch,
+         }
+         await batch_queue.put((batch_idx, task))
+
+     # initialize workers
+     n_workers = min(n_total_batches, n_workers)
+     workers = [
+         asyncio.create_task(
+             _worker(
+                 name=name,
+                 prompt=prompt,
+                 columns=columns,
+                 foreign_keys=foreign_keys,
+                 primary_keys=primary_keys,
+                 previous_rows=previous_rows,
+                 batch_queue=batch_queue,
+                 result_queue=result_queue,
+                 retry_queue=retry_queue,
+                 n_workers=n_workers,
+                 llm_output_format=llm_output_format,
+                 llm_config=llm_config,
+             )
          )
-         messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": llm_prompt}]
-
-         if n_enforced_columns != 0:
-             response = completion_with_retries(messages=messages, response_format=response_format, **litellm_kwargs)
-             rows_stream = yield_rows_from_json_chunks_stream(response)
-         else:
-             # skip roundtrip to LLM in case all columns are provided in existing data
-             rows_stream = itertools.repeat({})
-
-         batch_row_idx = 0
-         while True:
-             try:
-                 row_generated_part = next(rows_stream)
-                 row_existing_part = existing_batch.iloc[batch_row_idx].to_dict() if existing_batch is not None else {}
-                 row = {**row_existing_part, **row_generated_part}
-                 row = {column: row[column] for column in columns.keys()}  # keep columns order according to user's spec
-             except StopIteration:
-                 break  # move to next batch
-             previous_rows.append(row)
-             yield row
-             if context_batch is None:
-                 # each subject row is considered a single sequence
-                 yielded_sequences += 1
-                 if yielded_sequences >= sample_size:
-                     return  # move to next table
-             batch_row_idx += 1
-         if context_batch is not None:
-             # for each context_batch, full sequences are generated
-             yielded_sequences += len(context_batch)
-             if yielded_sequences >= sample_size:
-                 return  # move to next table
-
-         batch_idx += 1
-
-
- def _convert_table_rows_generator_to_df(
-     table_rows_generator: Generator[dict],
+         for _ in range(n_workers)
+     ]
+
+     n_completed_batches = 0
+     n_yielded_sequences = 0
+     while n_yielded_sequences < sample_size:
+         if n_completed_batches >= n_total_batches:
+             assert context_data is None, "n_total_batches is fixed for linked tables"
+             assert existing_data is None, "n_total_batches is fixed for data enrichment"
+             # LLMs may not generate exactly the number of rows requested
+             # in case of flat tables, we still accept such incomplete batches,
+             # but that means we may need to generate more batches to reach the sample size
+             # +2 because LLM may not always count the rows correctly
+             n_total_batches += 1
+             task = {
+                 "batch_size": sample_size - n_yielded_sequences + 2,
+             }
+             await batch_queue.put((n_total_batches, task))
+         batch_idx, result = await result_queue.get()
+         if isinstance(result, Exception):
+             # if an exception is raised by any worker, cancel all workers and raise that exception
+             for worker in workers:
+                 worker.cancel()
+             await asyncio.gather(*workers)
+             raise result
+         rows = result
+         for row_idx, row in enumerate(rows):
+             yield (batch_idx, row)
+             if context_batches is None or row_idx == len(rows) - 1:
+                 # in case of flat table, each row is considered a single sequence
+                 # in case of linked table, all rows are considered a single sequence
+                 # NOTE: this assumes that we generate a single sequence per batch
+                 n_yielded_sequences += 1
+                 if n_yielded_sequences >= sample_size:
+                     break
+         n_completed_batches += 1
+         result_queue.task_done()
+
+     # gracefully shutdown workers
+     await batch_queue.join()
+     for _ in workers:
+         await batch_queue.put((n_total_batches + 1, None))
+     await asyncio.gather(*workers)
+
+
+ async def _convert_table_rows_generator_to_df(
+     table_rows_generator: AsyncGenerator[dict],
      columns: dict[str, ColumnConfig],
  ) -> pd.DataFrame:
      def align_df_dtypes_with_mock_dtypes(df: pd.DataFrame, columns: dict[str, ColumnConfig]) -> pd.DataFrame:
+         df = df.copy()
          for column_name, column_config in columns.items():
              if column_config.dtype in [DType.DATE, DType.DATETIME]:
-                 df[column_name] = pd.to_datetime(df[column_name], errors="coerce")
+
+                 def harmonize_datetime(x):
+                     try:
+                         return dateutil.parser.parse(x)
+                     except Exception:
+                         return pd.NaT
+
+                 df[column_name] = pd.to_datetime(df[column_name].apply(harmonize_datetime), errors="coerce")
              elif column_config.dtype is DType.INTEGER:
                  df[column_name] = pd.to_numeric(df[column_name], errors="coerce", downcast="integer").astype(
                      "int64[pyarrow]"
@@ -631,6 +896,8 @@ def _convert_table_rows_generator_to_df(
              elif column_config.dtype is DType.FLOAT:
                  df[column_name] = pd.to_numeric(df[column_name], errors="coerce").astype("double[pyarrow]")
              elif column_config.dtype is DType.BOOLEAN:
+                 df[column_name] = df[column_name].map(lambda x: True if str(x).lower() == "true" else x)
+                 df[column_name] = df[column_name].map(lambda x: False if str(x).lower() == "false" else x)
                  df[column_name] = pd.to_numeric(df[column_name], errors="coerce").astype("boolean[pyarrow]")
              elif column_config.dtype is DType.CATEGORY:
                  df[column_name] = pd.Categorical(df[column_name], categories=column_config.values)
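
Note: the two `map` calls added above normalize string booleans before the cast, since CSV-mode responses arrive as text. A quick check of the behavior (assumes pandas with pyarrow installed, which the arrow-backed dtypes above already require):

```python
import pandas as pd

s = pd.Series(["True", "false", "yes", None])
s = s.map(lambda x: True if str(x).lower() == "true" else x)
s = s.map(lambda x: False if str(x).lower() == "false" else x)
# should print True, False, <NA>, <NA>: "yes" and missing values coerce to NA
print(pd.to_numeric(s, errors="coerce").astype("boolean[pyarrow]"))
```
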
@@ -638,7 +905,13 @@ def _convert_table_rows_generator_to_df(
              df[column_name] = df[column_name].astype("string[pyarrow]")
          return df

-     df = pd.DataFrame(list(table_rows_generator))
+     # consume entire generator
+     items = [{"batch_idx": batch_idx, "row": row} async for batch_idx, row in table_rows_generator]
+     # sort items by batch_idx to maintain order (relevant especially for keeping the order of existing data)
+     items = sorted(items, key=lambda x: x["batch_idx"])
+     # extract rows and convert to DataFrame
+     rows = [item["row"] for item in items]
+     df = pd.DataFrame(rows)
      df = align_df_dtypes_with_mock_dtypes(df, columns)
      return df

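Note: since workers finish out of order, every yielded row carries its batch index, and the sort above restores the original batch order before the DataFrame is built. In miniature:

```python
import pandas as pd

# (batch_idx, row) pairs in completion order, as the workers deliver them
items = [(2, {"a": 5}), (0, {"a": 1}), (1, {"a": 3})]
rows = [row for _, row in sorted(items, key=lambda item: item[0])]
print(pd.DataFrame(rows)["a"].tolist())  # -> [1, 3, 5], batch order restored
```
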
@@ -743,10 +1016,11 @@
      api_key: str | None = None,
      temperature: float = 1.0,
      top_p: float = 0.95,
+     n_workers: int = 10,
      return_type: Literal["auto", "dict"] = "auto",
  ) -> pd.DataFrame | dict[str, pd.DataFrame]:
      """
-     Generate mock data from scratch or enrich existing data by prompting an LLM.
+     Generate synthetic data from scratch or enrich existing data with new columns.

      While faker and numpy are useful to create fake data, this utility is unique as it allows
      the creation of coherent, realistic multi-table tabular mock data
@@ -765,7 +1039,7 @@
              If a table has a foreign key, the sample size is determined by the corresponding foreign key prompt. If nothing specified, a few rows per parent record are generated.
          existing_data (dict[str, pd.DataFrame] | None): Existing data to augment. If provided, the sample_size argument is ignored.
              Default is None.
-         model (str): The LiteLLM chat completion model to be used. Model needs to support structured output / JSON mode.
+         model (str): The LiteLLM chat completion model to be used.
              Examples include:
              - `openai/gpt-4.1-nano` (default; fast, and smart)
              - `openai/gpt-4.1-mini` (slower, but smarter)
@@ -779,6 +1053,8 @@
          api_key (str | None): The API key to use for the LLM. If not provided, LiteLLM will take it from the environment variables.
          temperature (float): The temperature to use for the LLM. Default is 1.0.
          top_p (float): The top-p value to use for the LLM. Default is 0.95.
+         n_workers (int): The number of concurrent workers making the LLM calls. Default is 10. The value is clamped to the range [1, 10].
+             If n_workers is 1, the generation of batches becomes sequential and certain features for better data consistency are enabled.
          return_type (Literal["auto", "dict"]): The format of the returned data. Default is "auto".

      Returns:
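
Note: a hedged usage sketch for the new `n_workers` parameter; the table config follows the shape implied by the docstring above, but the concrete table, prompts, and columns here are illustrative only:

```python
from mostlyai import mock

tables = {
    "guests": {
        "prompt": "Guests of a hotel in Switzerland",
        "columns": {
            "nationality": {"prompt": "2-letter code for the nationality", "dtype": "string"},
            "age": {"prompt": "age in years; between 18 and 80", "dtype": "integer"},
        },
    }
}
# n_workers is clamped to [1, 10]; n_workers=1 switches to sequential
# generation and feeds previously generated rows back into the prompt
df = mock.sample(tables=tables, sample_size=10, model="openai/gpt-4.1-nano", n_workers=4)
print(df.shape)  # (10, 2), assuming the model returns the requested rows
```
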
@@ -967,25 +1243,31 @@
      sample_size: dict[str, int] = _harmonize_sample_size(sample_size, config)
      primary_keys = {table_name: table_config.primary_key for table_name, table_config in config.root.items()}

+     n_workers = max(min(n_workers, 10), 1)
+
      execution_plan: list[str] = _build_execution_plan(config)

      data: dict[str, pd.DataFrame] = existing_data or {}

      for table_name in execution_plan:
          table_config = config.root[table_name]
-         df = _sample_table(
-             name=table_name,
-             prompt=table_config.prompt,
-             columns=table_config.columns,
-             foreign_keys=table_config.foreign_keys,
-             primary_keys=primary_keys,
-             data=data,
-             sample_size=sample_size[table_name],
-             batch_size=20,  # generate 20 root table rows at a time
-             previous_rows_size=10,  # present 10 previously generated rows to the LLM
-             non_context_size=10,  # pick 10 rows to choose from for each non-context foreign key
-             llm_config=llm_config,
-         )
+
+         with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+             future = executor.submit(
+                 _sample_table_sync,
+                 name=table_name,
+                 prompt=table_config.prompt,
+                 columns=table_config.columns,
+                 foreign_keys=table_config.foreign_keys,
+                 primary_keys=primary_keys,
+                 data=data,
+                 sample_size=sample_size[table_name],
+                 previous_rows_size=10,  # present 10 previously generated rows to the LLM
+                 non_context_size=10,  # pick 10 rows to choose from for each non-context foreign key
+                 n_workers=n_workers,
+                 llm_config=llm_config,
+             )
+             df = future.result()
          data[table_name] = df

      return next(iter(data.values())) if len(data) == 1 and return_type == "auto" else data
mostlyai/mock/mcp_server.py CHANGED
@@ -21,7 +21,9 @@ from fastmcp import FastMCP
  from mostlyai import mock

  SAMPLE_MOCK_TOOL_DESCRIPTION = f"""
- Generate mock data by prompting an LLM.
+ Synthetic Mock Data.
+
+ Use LLMs to generate any Tabular Data towards your needs. Create from scratch, expand existing datasets, or enrich tables with new columns.

  This tool is a proxy to the `mostlyai.mock.sample` function, but returns a dictionary of paths to the generated CSV files.

mostlyai_mock-0.1.9.dist-info/METADATA CHANGED
@@ -1,7 +1,7 @@
  Metadata-Version: 2.4
  Name: mostlyai-mock
- Version: 0.1.7
- Summary: LLM-generated Mock Data
+ Version: 0.1.9
+ Summary: Synthetic Mock Data
  Project-URL: homepage, https://github.com/mostly-ai/mostlyai-mock
  Project-URL: repository, https://github.com/mostly-ai/mostlyai-mock
  Project-URL: documentation, https://mostly-ai.github.io/mostlyai-mock/
@@ -33,19 +33,20 @@ Requires-Dist: pydantic<3.0.0,>=2.0.0
  Requires-Dist: tenacity>=9.1.2
  Description-Content-Type: text/markdown

- # LLM-generated Mock Data 🔮
+ # Synthetic Mock Data 🔮

  [![Documentation](https://img.shields.io/badge/docs-latest-green)](https://mostly-ai.github.io/mostlyai-mock/) [![stats](https://pepy.tech/badge/mostlyai-mock)](https://pypi.org/project/mostlyai-mock/) ![license](https://img.shields.io/github/license/mostly-ai/mostlyai-mock) ![GitHub Release](https://img.shields.io/github/v/release/mostly-ai/mostlyai-mock)

- Create data out of nothing. Prompt LLMs for Tabular Data.
+ Use LLMs to generate any Tabular Data towards your needs. Create from scratch, expand existing datasets, or enrich tables with new columns. Your prompts, your rules, your data.

  ## Key Features

- * A light-weight python client for prompting LLMs for mixed-type tabular data
- * Select from a range of LLM endpoints, that provide structured output
+ * A light-weight python client for prompting LLMs for mixed-type tabular data.
+ * Select from a wide range of LLM endpoints and LLM models.
  * Supports single-table as well as multi-table scenarios.
  * Supports variety of data types: `string`, `categorical`, `integer`, `float`, `boolean`, `date`, and `datetime`.
  * Specify context, distributions and rules via dataset-, table- or column-level prompts.
+ * Create from scratch or enrich existing datasets with new columns and/or rows.
  * Tailor the diversity and realism of your generated data via temperature and top_p.

  ## Getting Started
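
Note: a hedged sketch of the enrichment path highlighted in the new feature bullet; the config format mirrors the `sample()` docstring, with illustrative column prompts:

```python
import pandas as pd
from mostlyai import mock

existing = pd.DataFrame({"name": ["Alice", "Bob"], "age": [34, 27]})
tables = {
    "patients": {
        "prompt": "Patients of a hospital",
        "columns": {
            "name": {"prompt": "first name", "dtype": "string"},
            "age": {"prompt": "age in years", "dtype": "integer"},
            "diagnosis": {"prompt": "a plausible diagnosis", "dtype": "string"},
        },
    }
}
# sample_size is ignored when existing_data is given; only the missing
# `diagnosis` column is generated, row-aligned with the existing data
enriched = mock.sample(tables=tables, existing_data={"patients": existing})
print(enriched)
```
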
mostlyai_mock-0.1.9.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ mostlyai/mock/__init__.py,sha256=Kh6EZ-dkw2v6SuuWM0ygkKVJmBslJNP2sTsu_CE2JXM,714
+ mostlyai/mock/core.py,sha256=Os_nBMrwvB_noIwBMkA0R1z6jzvFcUByGEPoGIyoh78,54577
+ mostlyai/mock/mcp_server.py,sha256=MrVUrIsAZsFzjK1suwNl1fxS1ES-wpc-YSM8cS8Fqcw,2259
+ mostlyai_mock-0.1.9.dist-info/METADATA,sha256=J59K-sxrbMat0x87P1mObxjUsELCGEBYhS_k-raT_-8,14099
+ mostlyai_mock-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ mostlyai_mock-0.1.9.dist-info/entry_points.txt,sha256=XDbppUIAaCWW0nresVep8zb71pkzZuFA16jCBHq8CU8,61
+ mostlyai_mock-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ mostlyai_mock-0.1.9.dist-info/RECORD,,
mostlyai_mock-0.1.7.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- mostlyai/mock/__init__.py,sha256=Cmo4Ko8-X41gSewcEpNTTvw7bpRUrtn6B5Cmnwric-Q,714
- mostlyai/mock/core.py,sha256=L-PbOTSIR1cfBeMZL8-v5k7VhxBfKAoyw230soBwQWc,42754
- mostlyai/mock/mcp_server.py,sha256=kWMIjKCwnvYfjY8B2IdP4JNs8ik_8jA6ISCDqrG9utc,2137
- mostlyai_mock-0.1.7.dist-info/METADATA,sha256=6tLpoqLx-LOI-Cr_O_xWm4LI5PBfa4nt1FkrqdNIpQA,13918
- mostlyai_mock-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- mostlyai_mock-0.1.7.dist-info/entry_points.txt,sha256=XDbppUIAaCWW0nresVep8zb71pkzZuFA16jCBHq8CU8,61
- mostlyai_mock-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- mostlyai_mock-0.1.7.dist-info/RECORD,,