Flowfile 0.4.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. flowfile/__init__.py +3 -1
  2. flowfile/api.py +1 -2
  3. flowfile/web/static/assets/{CloudConnectionManager-d3248f8d.js → CloudConnectionManager-0dfba9f2.js} +2 -2
  4. flowfile/web/static/assets/{CloudStorageReader-d65bf041.js → CloudStorageReader-d5b1b6c9.js} +6 -6
  5. flowfile/web/static/assets/{CloudStorageWriter-e83be3ed.js → CloudStorageWriter-00d87aad.js} +6 -6
  6. flowfile/web/static/assets/{ColumnSelector-cce661cf.js → ColumnSelector-4685e75d.js} +1 -1
  7. flowfile/web/static/assets/{ContextMenu-cf18d2cc.js → ContextMenu-23e909da.js} +1 -1
  8. flowfile/web/static/assets/{ContextMenu-160afb08.js → ContextMenu-70ae0c79.js} +1 -1
  9. flowfile/web/static/assets/{ContextMenu-11a4652a.js → ContextMenu-f149cf7c.js} +1 -1
  10. flowfile/web/static/assets/{CrossJoin-d395d38c.js → CrossJoin-702a3edd.js} +7 -7
  11. flowfile/web/static/assets/{CustomNode-b812dc0b.js → CustomNode-b1519993.js} +11 -11
  12. flowfile/web/static/assets/{DatabaseConnectionSettings-7000bf2c.js → DatabaseConnectionSettings-6f3e4ea5.js} +2 -2
  13. flowfile/web/static/assets/{DatabaseManager-9662ec5b.js → DatabaseManager-cf5ef661.js} +2 -2
  14. flowfile/web/static/assets/{DatabaseReader-4f035d0c.js → DatabaseReader-d38c7295.js} +9 -9
  15. flowfile/web/static/assets/{DatabaseWriter-f65dcd54.js → DatabaseWriter-b04ef46a.js} +8 -8
  16. flowfile/web/static/assets/{ExploreData-94c43dfc.js → ExploreData-5fa10ed8.js} +5 -5
  17. flowfile/web/static/assets/{ExternalSource-ac04b3cc.js → ExternalSource-d39af878.js} +5 -5
  18. flowfile/web/static/assets/{Filter-812dcbca.js → Filter-9b6d08db.js} +7 -7
  19. flowfile/web/static/assets/{Formula-71472193.js → Formula-6b04fb1d.js} +7 -7
  20. flowfile/web/static/assets/{FuzzyMatch-b317f631.js → FuzzyMatch-999521f4.js} +8 -8
  21. flowfile/web/static/assets/{GraphSolver-754a234f.js → GraphSolver-17dd2198.js} +6 -6
  22. flowfile/web/static/assets/{GroupBy-6c6f9802.js → GroupBy-6b039e18.js} +5 -5
  23. flowfile/web/static/assets/{Join-a1b800be.js → Join-24d0f113.js} +8 -8
  24. flowfile/web/static/assets/{ManualInput-a9640276.js → ManualInput-34639209.js} +4 -4
  25. flowfile/web/static/assets/{MultiSelect-97213888.js → MultiSelect-0e8724a3.js} +2 -2
  26. flowfile/web/static/assets/{MultiSelect.vue_vue_type_script_setup_true_lang-6ffe088a.js → MultiSelect.vue_vue_type_script_setup_true_lang-b0e538c2.js} +1 -1
  27. flowfile/web/static/assets/{NumericInput-e638088a.js → NumericInput-3d63a470.js} +2 -2
  28. flowfile/web/static/assets/{NumericInput.vue_vue_type_script_setup_true_lang-90eb2cba.js → NumericInput.vue_vue_type_script_setup_true_lang-e0edeccc.js} +1 -1
  29. flowfile/web/static/assets/{Output-ddc9079f.css → Output-283fe388.css} +5 -5
  30. flowfile/web/static/assets/{Output-76750610.js → Output-edea9802.js} +57 -38
  31. flowfile/web/static/assets/{Pivot-7814803f.js → Pivot-61d19301.js} +7 -7
  32. flowfile/web/static/assets/{PivotValidation-f92137d2.js → PivotValidation-de9f43fe.js} +1 -1
  33. flowfile/web/static/assets/{PivotValidation-76dd431a.js → PivotValidation-f97fec5b.js} +1 -1
  34. flowfile/web/static/assets/{PolarsCode-889c3008.js → PolarsCode-bc3c9984.js} +5 -5
  35. flowfile/web/static/assets/{Read-637b72a7.js → Read-64a3f259.js} +80 -105
  36. flowfile/web/static/assets/{Read-6b17491f.css → Read-e808b239.css} +10 -10
  37. flowfile/web/static/assets/{RecordCount-2b050c41.js → RecordCount-3d5039be.js} +4 -4
  38. flowfile/web/static/assets/{RecordId-81df7784.js → RecordId-597510e0.js} +6 -6
  39. flowfile/web/static/assets/{SQLQueryComponent-88dcfe53.js → SQLQueryComponent-df51adbe.js} +1 -1
  40. flowfile/web/static/assets/{Sample-258ad2a9.js → Sample-4be0a507.js} +4 -4
  41. flowfile/web/static/assets/{SecretManager-2a2cb7e2.js → SecretManager-4839be57.js} +2 -2
  42. flowfile/web/static/assets/{Select-850215fd.js → Select-9b72f201.js} +7 -7
  43. flowfile/web/static/assets/{SettingsSection-29b4fa6b.js → SettingsSection-7ded385d.js} +1 -1
  44. flowfile/web/static/assets/{SettingsSection-0e8d9123.js → SettingsSection-e1e9c953.js} +1 -1
  45. flowfile/web/static/assets/{SettingsSection-55bae608.js → SettingsSection-f0f75a42.js} +1 -1
  46. flowfile/web/static/assets/{SingleSelect-bebd408b.js → SingleSelect-6c777aac.js} +2 -2
  47. flowfile/web/static/assets/{SingleSelect.vue_vue_type_script_setup_true_lang-6093741c.js → SingleSelect.vue_vue_type_script_setup_true_lang-33e3ff9b.js} +1 -1
  48. flowfile/web/static/assets/{SliderInput-6a05ab61.js → SliderInput-7cb93e62.js} +1 -1
  49. flowfile/web/static/assets/{Sort-10ab48ed.js → Sort-6cbde21a.js} +5 -5
  50. flowfile/web/static/assets/{TextInput-df9d6259.js → TextInput-d9a40c11.js} +2 -2
  51. flowfile/web/static/assets/{TextInput.vue_vue_type_script_setup_true_lang-000e1178.js → TextInput.vue_vue_type_script_setup_true_lang-5896c375.js} +1 -1
  52. flowfile/web/static/assets/{TextToRows-6c2d93d8.js → TextToRows-c4fcbf4d.js} +7 -7
  53. flowfile/web/static/assets/{ToggleSwitch-0ff7ac52.js → ToggleSwitch-4ef91d19.js} +2 -2
  54. flowfile/web/static/assets/{ToggleSwitch.vue_vue_type_script_setup_true_lang-c6dc3029.js → ToggleSwitch.vue_vue_type_script_setup_true_lang-38478c20.js} +1 -1
  55. flowfile/web/static/assets/{UnavailableFields-1bab97cb.js → UnavailableFields-a03f512c.js} +2 -2
  56. flowfile/web/static/assets/{Union-b563478a.js → Union-bfe9b996.js} +4 -4
  57. flowfile/web/static/assets/{Unique-f90db5db.js → Unique-5d023a27.js} +8 -20
  58. flowfile/web/static/assets/{Unpivot-bcb0025f.js → Unpivot-91cc5354.js} +6 -6
  59. flowfile/web/static/assets/{UnpivotValidation-c4e73b04.js → UnpivotValidation-7ee2de44.js} +1 -1
  60. flowfile/web/static/assets/{VueGraphicWalker-bb8535e2.js → VueGraphicWalker-e51b9924.js} +1 -1
  61. flowfile/web/static/assets/{api-2d6adc4f.js → api-c1bad5ca.js} +1 -1
  62. flowfile/web/static/assets/{api-4c8e3822.js → api-cf1221f0.js} +1 -1
  63. flowfile/web/static/assets/{designer-e3c150ec.css → designer-8da3ba3a.css} +90 -67
  64. flowfile/web/static/assets/{designer-f3656d8c.js → designer-9633482a.js} +119 -51
  65. flowfile/web/static/assets/{documentation-52b241e7.js → documentation-ca400224.js} +1 -1
  66. flowfile/web/static/assets/{dropDown-1bca8a74.js → dropDown-614b998d.js} +1 -1
  67. flowfile/web/static/assets/{fullEditor-2985687e.js → fullEditor-f7971590.js} +2 -2
  68. flowfile/web/static/assets/{genericNodeSettings-0476ba4e.js → genericNodeSettings-4fe5f36b.js} +3 -3
  69. flowfile/web/static/assets/{index-246f201c.js → index-5429bbf8.js} +6 -8
  70. flowfile/web/static/assets/nodeInput-5d0d6b79.js +41 -0
  71. flowfile/web/static/assets/{outputCsv-d686eeaf.js → outputCsv-076b85ab.js} +1 -1
  72. flowfile/web/static/assets/{outputExcel-8809ea2f.js → outputExcel-0fd17dbe.js} +1 -1
  73. flowfile/web/static/assets/{outputParquet-53ba645a.js → outputParquet-b61e0847.js} +1 -1
  74. flowfile/web/static/assets/{readCsv-053bf97b.js → readCsv-a8bb8b61.js} +21 -20
  75. flowfile/web/static/assets/{readCsv-bca3ed53.css → readCsv-c767cb37.css} +13 -13
  76. flowfile/web/static/assets/{readExcel-ad531eab.js → readExcel-67b4aee0.js} +10 -12
  77. flowfile/web/static/assets/{readExcel-e1b381ea.css → readExcel-806d2826.css} +12 -12
  78. flowfile/web/static/assets/{readParquet-cee068e2.css → readParquet-48c81530.css} +3 -3
  79. flowfile/web/static/assets/{readParquet-58e899a1.js → readParquet-92ce1dbc.js} +4 -7
  80. flowfile/web/static/assets/{secretApi-538058f3.js → secretApi-68435402.js} +1 -1
  81. flowfile/web/static/assets/{selectDynamic-b38de2ba.js → selectDynamic-92e25ee3.js} +3 -3
  82. flowfile/web/static/assets/{vue-codemirror.esm-db9b8936.js → vue-codemirror.esm-41b0e0d7.js} +7 -4
  83. flowfile/web/static/assets/{vue-content-loader.es-b5f3ac30.js → vue-content-loader.es-2c8e608f.js} +1 -1
  84. flowfile/web/static/index.html +1 -1
  85. {flowfile-0.4.1.dist-info → flowfile-0.5.1.dist-info}/METADATA +3 -2
  86. {flowfile-0.4.1.dist-info → flowfile-0.5.1.dist-info}/RECORD +138 -126
  87. {flowfile-0.4.1.dist-info → flowfile-0.5.1.dist-info}/WHEEL +1 -1
  88. {flowfile-0.4.1.dist-info → flowfile-0.5.1.dist-info}/entry_points.txt +1 -0
  89. flowfile_core/__init__.py +3 -0
  90. flowfile_core/flowfile/analytics/analytics_processor.py +1 -0
  91. flowfile_core/flowfile/code_generator/code_generator.py +62 -64
  92. flowfile_core/flowfile/flow_data_engine/create/funcs.py +73 -56
  93. flowfile_core/flowfile/flow_data_engine/flow_data_engine.py +77 -86
  94. flowfile_core/flowfile/flow_data_engine/fuzzy_matching/prepare_for_fuzzy_match.py +23 -23
  95. flowfile_core/flowfile/flow_data_engine/join/utils.py +1 -1
  96. flowfile_core/flowfile/flow_data_engine/join/verify_integrity.py +9 -4
  97. flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py +184 -78
  98. flowfile_core/flowfile/flow_data_engine/utils.py +2 -0
  99. flowfile_core/flowfile/flow_graph.py +129 -26
  100. flowfile_core/flowfile/flow_node/flow_node.py +3 -0
  101. flowfile_core/flowfile/flow_node/models.py +2 -1
  102. flowfile_core/flowfile/handler.py +5 -5
  103. flowfile_core/flowfile/manage/compatibility_enhancements.py +404 -41
  104. flowfile_core/flowfile/manage/io_flowfile.py +394 -0
  105. flowfile_core/flowfile/node_designer/__init__.py +1 -1
  106. flowfile_core/flowfile/node_designer/_type_registry.py +2 -2
  107. flowfile_core/flowfile/node_designer/custom_node.py +1 -1
  108. flowfile_core/flowfile/node_designer/ui_components.py +1 -1
  109. flowfile_core/flowfile/schema_callbacks.py +8 -5
  110. flowfile_core/flowfile/setting_generator/settings.py +15 -9
  111. flowfile_core/routes/routes.py +8 -10
  112. flowfile_core/schemas/cloud_storage_schemas.py +0 -2
  113. flowfile_core/schemas/input_schema.py +222 -65
  114. flowfile_core/schemas/output_model.py +1 -1
  115. flowfile_core/schemas/schemas.py +145 -32
  116. flowfile_core/schemas/transform_schema.py +1083 -413
  117. flowfile_core/schemas/yaml_types.py +103 -0
  118. flowfile_core/{flowfile/node_designer/data_types.py → types.py} +11 -1
  119. flowfile_frame/__init__.py +3 -1
  120. flowfile_frame/flow_frame.py +15 -18
  121. flowfile_frame/flow_frame_methods.py +12 -9
  122. flowfile_worker/__init__.py +3 -0
  123. flowfile_worker/create/__init__.py +3 -21
  124. flowfile_worker/create/funcs.py +68 -56
  125. flowfile_worker/create/models.py +130 -62
  126. flowfile_worker/routes.py +5 -8
  127. tools/migrate/README.md +56 -0
  128. tools/migrate/__init__.py +12 -0
  129. tools/migrate/__main__.py +131 -0
  130. tools/migrate/legacy_schemas.py +621 -0
  131. tools/migrate/migrate.py +598 -0
  132. tools/migrate/tests/__init__.py +0 -0
  133. tools/migrate/tests/conftest.py +23 -0
  134. tools/migrate/tests/test_migrate.py +627 -0
  135. tools/migrate/tests/test_migration_e2e.py +1010 -0
  136. tools/migrate/tests/test_node_migrations.py +813 -0
  137. flowfile_core/flowfile/manage/open_flowfile.py +0 -143
  138. {flowfile-0.4.1.dist-info → flowfile-0.5.1.dist-info}/licenses/LICENSE +0 -0
  139. /flowfile_core/flowfile/manage/manage_flowfile.py → /tools/__init__.py +0 -0
@@ -1,14 +1,22 @@
1
1
  from typing import List, Dict, Tuple, Set, Optional, Literal, Callable
2
- from dataclasses import dataclass, field
2
+ from dataclasses import asdict
3
3
  import polars as pl
4
4
  from polars import selectors
5
5
  from copy import deepcopy
6
+ from pydantic import BaseModel, ConfigDict, model_validator, Field
7
+ from typing import NamedTuple, Union, Any
8
+ from flowfile_core.schemas.yaml_types import (
9
+ SelectInputYaml, JoinInputsYaml, JoinInputYaml,
10
+ CrossJoinInputYaml, FuzzyMatchInputYaml
11
+ )
12
+ from pl_fuzzy_frame_match.models import FuzzyMapping
6
13
 
7
- from typing import NamedTuple
14
+ from flowfile_core.types import DataType, DataTypeStr
8
15
 
9
- from pl_fuzzy_frame_match.models import FuzzyMapping
16
+ FuzzyMap = FuzzyMapping
17
+
18
+ AUTO_DATA_TYPE = "Auto"
10
19
 
11
- FuzzyMap = FuzzyMapping # For backwards compatibility
12
20
 
13
21
  def get_func_type_mapping(func: str):
14
22
  """Infers the output data type of common aggregation functions."""
@@ -55,436 +63,489 @@ class FullJoinKeyResponse(NamedTuple):
55
63
  right: JoinKeyRenameResponse
56
64
 
57
65
 
58
- @dataclass
59
- class SelectInput:
66
+ class SelectInput(BaseModel):
60
67
  """Defines how a single column should be selected, renamed, or type-cast.
61
68
 
62
69
  This is a core building block for any operation that involves column manipulation.
63
70
  It holds all the configuration for a single field in a selection operation.
64
71
  """
72
+ model_config = ConfigDict(frozen=False)
73
+
65
74
  old_name: str
66
75
  original_position: Optional[int] = None
67
76
  new_name: Optional[str] = None
68
77
  data_type: Optional[str] = None
69
- data_type_change: Optional[bool] = False
70
- join_key: Optional[bool] = False
71
- is_altered: Optional[bool] = False
78
+ data_type_change: bool = False
79
+ join_key: bool = False
80
+ is_altered: bool = False
72
81
  position: Optional[int] = None
73
- is_available: Optional[bool] = True
74
- keep: Optional[bool] = True
82
+ is_available: bool = True
83
+ keep: bool = True
84
+
85
+ def __init__(self, old_name: str = None, new_name: str = None, **data):
86
+ if old_name is not None:
87
+ data['old_name'] = old_name
88
+ if new_name is not None:
89
+ data['new_name'] = new_name
90
+ super().__init__(**data)
91
+
92
+ def to_yaml_dict(self) -> SelectInputYaml:
93
+ """Serialize for YAML output - only user-relevant fields."""
94
+ result: SelectInputYaml = {"old_name": self.old_name}
95
+ if self.new_name != self.old_name:
96
+ result["new_name"] = self.new_name
97
+ if not self.keep:
98
+ result["keep"] = self.keep
99
+ if self.data_type_change and self.data_type:
100
+ result["data_type"] = self.data_type
101
+ return result
102
+
103
+ @classmethod
104
+ def from_yaml_dict(cls, data: dict) -> "SelectInput":
105
+ """Load from slim YAML format."""
106
+ old_name = data["old_name"]
107
+ new_name = data.get("new_name", old_name)
108
+ return cls(
109
+ old_name=old_name,
110
+ new_name=new_name,
111
+ keep=data.get("keep", True),
112
+ data_type=data.get("data_type"),
113
+ data_type_change=data.get("data_type") is not None,
114
+ is_altered=old_name != new_name,
115
+ )
116
+
117
+ @model_validator(mode='after')
118
+ def set_default_new_name(self):
119
+ """If new_name is None, default it to old_name."""
120
+ if self.new_name is None:
121
+ self.new_name = self.old_name
122
+ if self.old_name != self.new_name:
123
+ self.is_altered = True
124
+ return self
75
125
 
76
126
  def __hash__(self):
127
+ """Allow SelectInput to be used in sets and as dict keys."""
77
128
  return hash(self.old_name)
78
129
 
79
- def __init__(self, old_name: str, new_name: str = None, keep: bool = True, data_type: str = None,
80
- data_type_change: bool = False, join_key: bool = False, is_altered: bool = False,
81
- is_available: bool = True, position: int = None):
82
- self.old_name = old_name
83
- if new_name is None:
84
- new_name = old_name
85
- self.new_name = new_name
86
- self.keep = keep
87
- self.data_type = data_type
88
- self.data_type_change = data_type_change
89
- self.join_key = join_key
90
- self.is_altered = is_altered
91
- self.is_available = is_available
92
- self.position = position
130
+ def __eq__(self, other):
131
+ """Required when implementing __hash__."""
132
+ if not isinstance(other, SelectInput):
133
+ return False
134
+ return self.old_name == other.old_name
93
135
 
94
136
  @property
95
137
  def polars_type(self) -> str:
96
138
  """Translates a user-friendly type name to a Polars data type string."""
97
- if self.data_type.lower() == 'string':
139
+ data_type_lower = self.data_type.lower()
140
+ if data_type_lower == 'string':
98
141
  return 'Utf8'
99
- elif self.data_type.lower() == 'integer':
142
+ elif data_type_lower == 'integer':
100
143
  return 'Int64'
101
- elif self.data_type.lower() == 'double':
144
+ elif data_type_lower == 'double':
102
145
  return 'Float64'
103
146
  return self.data_type
104
147
 
105
148
 
106
- @dataclass
107
- class FieldInput:
149
+ class FieldInput(BaseModel):
108
150
  """Represents a single field with its name and data type, typically for defining an output column."""
109
151
  name: str
110
- data_type: Optional[str] = None
111
-
112
- def __init__(self, name: str, data_type: str = None):
113
- self.name = name
114
- self.data_type = data_type
152
+ data_type: DataType | Literal["Auto"] | DataTypeStr | None = AUTO_DATA_TYPE
115
153
 
116
154
 
117
- @dataclass
118
- class FunctionInput:
155
+ class FunctionInput(BaseModel):
119
156
  """Defines a formula to be applied, including the output field information."""
120
157
  field: FieldInput
121
158
  function: str
122
159
 
160
+ def __init__(self, field: FieldInput = None, function: str = None, **data):
161
+ if field is not None:
162
+ data['field'] = field
163
+ if function is not None:
164
+ data['function'] = function
165
+ super().__init__(**data)
123
166
 
124
- @dataclass
125
- class BasicFilter:
126
- """Defines a simple, single-condition filter (e.g., 'column' 'equals' 'value')."""
127
- field: str = ''
128
- filter_type: str = ''
129
- filter_value: str = ''
130
-
131
-
132
- @dataclass
133
- class FilterInput:
134
- """Defines the settings for a filter operation, supporting basic or advanced (expression-based) modes."""
135
- advanced_filter: str = ''
136
- basic_filter: BasicFilter = None
137
- filter_type: str = 'basic'
138
-
139
-
140
- @dataclass
141
- class SelectInputs:
142
- """A container for a list of `SelectInput` objects, providing helper methods for managing selections."""
143
- renames: List[SelectInput]
144
-
145
- @property
146
- def old_cols(self) -> Set:
147
- """Returns a set of original column names to be kept in the selection."""
148
- return set(v.old_name for v in self.renames if v.keep)
149
-
150
- @property
151
- def new_cols(self) -> Set:
152
- """Returns a set of new (renamed) column names to be kept in the selection."""
153
- return set(v.new_name for v in self.renames if v.keep)
154
-
155
- @property
156
- def rename_table(self):
157
- """Generates a dictionary for use in Polars' `.rename()` method."""
158
- return {v.old_name: v.new_name for v in self.renames if v.is_available and (v.keep or v.join_key)}
159
-
160
- def get_select_cols(self, include_join_key: bool = True):
161
- """Gets a list of original column names to select from the source DataFrame."""
162
- return [v.old_name for v in self.renames if v.keep or (v.join_key and include_join_key)]
163
-
164
- def has_drop_cols(self) -> bool:
165
- """Checks if any column is marked to be dropped from the selection."""
166
- return any(not v.keep for v in self.renames)
167
167
 
168
- @property
169
- def drop_columns(self) -> List[SelectInput]:
170
- """Returns a list of column names that are marked to be dropped from the selection."""
171
- return [v for v in self.renames if not v.keep and v.is_available]
168
+ class BasicFilter(BaseModel):
169
+ """Defines a simple, single-condition filter (e.g., 'column' 'equals' 'value')."""
170
+ field: Optional[str] = ''
171
+ filter_type: Optional[str] = ''
172
+ filter_value: Optional[str] = ''
172
173
 
173
- @property
174
- def non_jk_drop_columns(self) -> List[SelectInput]:
175
- return [v for v in self.renames if not v.keep and v.is_available and not v.join_key]
174
+ def __init__(self, field: str = None, filter_type: str = None, filter_value: str = None, **data):
175
+ if field is not None:
176
+ data['field'] = field
177
+ if filter_type is not None:
178
+ data['filter_type'] = filter_type
179
+ if filter_value is not None:
180
+ data['filter_value'] = filter_value
181
+ super().__init__(**data)
176
182
 
177
- def __add__(self, other: "SelectInput"):
178
- """Allows adding a SelectInput using the '+' operator."""
179
- self.renames.append(other)
180
183
 
181
- def append(self, other: "SelectInput"):
182
- """Appends a new SelectInput to the list of renames."""
183
- self.renames.append(other)
184
+ class FilterInput(BaseModel):
185
+ """Defines the settings for a filter operation, supporting basic or advanced (expression-based) modes."""
186
+ advanced_filter: Optional[str] = ''
187
+ basic_filter: Optional[BasicFilter] = None
188
+ filter_type: Optional[str] = 'basic'
189
+
190
+ def __init__(self, advanced_filter: str = None, basic_filter: BasicFilter = None,
191
+ filter_type: str = None, **data):
192
+ if advanced_filter is not None:
193
+ data['advanced_filter'] = advanced_filter
194
+ if basic_filter is not None:
195
+ data['basic_filter'] = basic_filter
196
+ if filter_type is not None:
197
+ data['filter_type'] = filter_type
198
+ super().__init__(**data)
199
+
200
+
201
+ class SelectInputs(BaseModel):
202
+ """A container for a list of `SelectInput` objects (pure data, no logic)."""
203
+ renames: List[SelectInput] = Field(default_factory=list)
204
+
205
+ def __init__(self, renames: List[SelectInput] = None, **kwargs):
206
+ if renames is not None:
207
+ kwargs['renames'] = renames
208
+ else:
209
+ kwargs['renames'] = []
210
+ super().__init__(**kwargs)
184
211
 
185
- def remove_select_input(self, old_key: str):
186
- """Removes a SelectInput from the list based on its original name."""
187
- self.renames = [rename for rename in self.renames if rename.old_name != old_key]
212
+ def to_yaml_dict(self) -> JoinInputsYaml:
213
+ """Serialize for YAML output."""
214
+ return {"select": [r.to_yaml_dict() for r in self.renames]}
188
215
 
189
- def unselect_field(self, old_key: str):
190
- """Marks a field to be dropped from the final selection by setting `keep` to False."""
191
- for rename in self.renames:
192
- if old_key == rename.old_name:
193
- rename.keep = False
216
+ @classmethod
217
+ def from_yaml_dict(cls, data: dict) -> "SelectInputs":
218
+ """Load from slim YAML format. Supports both 'select' (new) and 'renames' (internal)."""
219
+ items = data.get("select", data.get("renames", []))
220
+ return cls(renames=[SelectInput.from_yaml_dict(item) for item in items])
194
221
 
195
222
  @classmethod
196
- def create_from_list(cls, col_list: List[str]):
223
+ def create_from_list(cls, col_list: List[str]) -> "SelectInputs":
197
224
  """Creates a SelectInputs object from a simple list of column names."""
198
- return cls([SelectInput(c) for c in col_list])
225
+ return cls(renames=[SelectInput(old_name=c) for c in col_list])
199
226
 
200
227
  @classmethod
201
- def create_from_pl_df(cls, df: pl.DataFrame | pl.LazyFrame):
228
+ def create_from_pl_df(cls, df: pl.DataFrame | pl.LazyFrame) -> "SelectInputs":
202
229
  """Creates a SelectInputs object from a Polars DataFrame's columns."""
203
- return cls([SelectInput(c) for c in df.columns])
204
-
205
- def get_select_input_on_old_name(self, old_name: str) -> SelectInput | None:
206
- return next((v for v in self.renames if v.old_name == old_name), None)
207
-
208
- def get_select_input_on_new_name(self, old_name: str) -> SelectInput | None:
209
- return next((v for v in self.renames if v.new_name == old_name), None)
230
+ return cls(renames=[SelectInput(old_name=c) for c in df.columns])
210
231
 
211
232
 
212
233
  class JoinInputs(SelectInputs):
213
- """Extends `SelectInputs` with functionality specific to join operations, like handling join keys."""
214
-
215
- def __init__(self, renames: List[SelectInput]):
216
- self.renames = renames
217
-
218
- @property
219
- def join_key_selects(self) -> List[SelectInput]:
220
- """Returns only the `SelectInput` objects that are marked as join keys."""
221
- return [v for v in self.renames if v.join_key]
234
+ """Data model for join-specific select inputs (extends SelectInputs)."""
222
235
 
223
- def get_join_key_renames(self, side: SideLit, filter_drop: bool = False) -> JoinKeyRenameResponse:
224
- """Gets the temporary rename mapping for all join keys on one side of a join."""
225
- return JoinKeyRenameResponse(
226
- side,
227
- [JoinKeyRename(jk.new_name,
228
- construct_join_key_name(side, jk.new_name))
229
- for jk in self.join_key_selects if jk.keep or not filter_drop]
230
- )
231
-
232
- def get_join_key_rename_mapping(self, side: SideLit) -> Dict[str, str]:
233
- """Returns a dictionary mapping original join key names to their temporary names."""
234
- return {jkr[0]: jkr[1] for jkr in self.get_join_key_renames(side)[1]}
236
+ def __init__(self, renames: List[SelectInput] = None, **kwargs):
237
+ if renames is not None:
238
+ kwargs['renames'] = renames
239
+ else:
240
+ kwargs['renames'] = []
241
+ super().__init__(**kwargs)
235
242
 
236
243
 
237
- @dataclass
238
- class JoinMap:
244
+ class JoinMap(BaseModel):
239
245
  """Defines a single mapping between a left and right column for a join key."""
240
- left_col: str
241
- right_col: str
242
-
243
-
244
- class JoinSelectMixin:
245
- """A mixin providing common methods for join-like operations that involve left and right inputs."""
246
- left_select: JoinInputs = None
247
- right_select: JoinInputs = None
248
-
249
- @staticmethod
250
- def parse_select(select: List[SelectInput] | List[str] | List[Dict]) -> JoinInputs | None:
251
- """Parses various input formats into a standardized `JoinInputs` object."""
252
- if all(isinstance(c, SelectInput) for c in select):
253
- return JoinInputs(select)
254
- elif all(isinstance(c, dict) for c in select):
255
- return JoinInputs([SelectInput(**c.__dict__) for c in select])
256
- elif isinstance(select, dict):
257
- renames = select.get('renames')
258
- if renames:
259
- return JoinInputs([SelectInput(**c) for c in renames])
260
- elif all(isinstance(c, str) for c in select):
261
- return JoinInputs([SelectInput(s, s) for s in select])
262
-
263
- def auto_generate_new_col_name(self, old_col_name: str, side: str) -> str:
264
- """Generates a new, non-conflicting column name by adding a suffix if necessary."""
265
- current_names = self.left_select.new_cols & self.right_select.new_cols
266
- if old_col_name not in current_names:
267
- return old_col_name
268
- while True:
269
- if old_col_name not in current_names:
270
- return old_col_name
271
- old_col_name = f'{side}_{old_col_name}'
272
-
273
- def add_new_select_column(self, select_input: SelectInput, side: str):
274
- """Adds a new column to the selection for either the left or right side."""
275
- selects = self.right_select if side == 'right' else self.left_select
276
- select_input.new_name = self.auto_generate_new_col_name(select_input.old_name, side=side)
277
- selects.__add__(select_input)
278
-
279
-
280
- @dataclass
281
- class CrossJoinInput(JoinSelectMixin):
282
- """Defines the settings for a cross join operation, including column selections for both inputs."""
283
- left_select: SelectInputs = None
284
- right_select: SelectInputs = None
246
+ left_col: Optional[str] = None
247
+ right_col: Optional[str] = None
248
+
249
+ def __init__(self, left_col: str = None, right_col: str = None, **data):
250
+ if left_col is not None:
251
+ data['left_col'] = left_col
252
+ if right_col is not None:
253
+ data['right_col'] = right_col
254
+ super().__init__(**data)
255
+
256
+ @model_validator(mode='after')
257
+ def set_default_right_col(self):
258
+ """If right_col is None, default it to left_col."""
259
+ if self.right_col is None:
260
+ self.right_col = self.left_col
261
+ return self
262
+
263
+
264
+ class CrossJoinInput(BaseModel):
265
+ """Data model for cross join operations."""
266
+ left_select: JoinInputs
267
+ right_select: JoinInputs
268
+
269
+ @model_validator(mode='before')
270
+ @classmethod
271
+ def parse_inputs(cls, data: Any) -> Any:
272
+ """Parse flexible input formats before validation."""
273
+ if isinstance(data, dict):
274
+ # Parse join_mapping
275
+ if 'join_mapping' in data:
276
+ data['join_mapping'] = cls._parse_join_mapping(data['join_mapping'])
285
277
 
286
- def __init__(self, left_select: List[SelectInput] | List[str],
287
- right_select: List[SelectInput] | List[str]):
288
- """Initializes the CrossJoinInput with selections for left and right tables."""
289
- self.left_select = self.parse_select(left_select)
290
- self.right_select = self.parse_select(right_select)
278
+ # Parse left_select
279
+ if 'left_select' in data:
280
+ data['left_select'] = cls._parse_select(data['left_select'])
291
281
 
292
- @property
293
- def overlapping_records(self):
294
- """Finds column names that would conflict after the join."""
295
- return self.left_select.new_cols & self.right_select.new_cols
282
+ # Parse right_select
283
+ if 'right_select' in data:
284
+ data['right_select'] = cls._parse_select(data['right_select'])
296
285
 
297
- def auto_rename(self):
298
- """Automatically renames columns on the right side to prevent naming conflicts."""
299
- overlapping_records = self.overlapping_records
300
- while len(overlapping_records) > 0:
301
- for right_col in self.right_select.renames:
302
- if right_col.new_name in overlapping_records:
303
- right_col.new_name = 'right_' + right_col.new_name
304
- overlapping_records = self.overlapping_records
286
+ return data
305
287
 
288
+ @staticmethod
289
+ def _parse_join_mapping(join_mapping: Any) -> List[JoinMap]:
290
+ """Parse various join_mapping formats."""
291
+ # Already a list of JoinMaps
292
+ if isinstance(join_mapping, list):
293
+ result = []
294
+ for jm in join_mapping:
295
+ if isinstance(jm, JoinMap):
296
+ result.append(jm)
297
+ elif isinstance(jm, dict):
298
+ result.append(JoinMap(**jm))
299
+ elif isinstance(jm, (tuple, list)) and len(jm) == 2:
300
+ result.append(JoinMap(left_col=jm[0], right_col=jm[1]))
301
+ elif isinstance(jm, str):
302
+ result.append(JoinMap(left_col=jm, right_col=jm))
303
+ else:
304
+ raise ValueError(f"Invalid join mapping item: {jm}")
305
+ return result
306
+
307
+ # Single JoinMap
308
+ if isinstance(join_mapping, JoinMap):
309
+ return [join_mapping]
310
+
311
+ # String: same column on both sides
312
+ if isinstance(join_mapping, str):
313
+ return [JoinMap(left_col=join_mapping, right_col=join_mapping)]
314
+
315
+ # Tuple: (left, right)
316
+ if isinstance(join_mapping, tuple) and len(join_mapping) == 2:
317
+ return [JoinMap(left_col=join_mapping[0], right_col=join_mapping[1])]
318
+
319
+ raise ValueError(f"Invalid join_mapping format: {type(join_mapping)}")
306
320
 
307
- @dataclass
308
- class JoinInput(JoinSelectMixin):
309
- """Defines the settings for a standard SQL-style join, including keys, strategy, and selections."""
321
+ @staticmethod
322
+ def _parse_select(select: Any) -> JoinInputs:
323
+ """Parse various select input formats."""
324
+ # Already JoinInputs
325
+ if isinstance(select, JoinInputs):
326
+ return select
327
+
328
+ # List of SelectInput objects
329
+ if isinstance(select, list):
330
+ if all(isinstance(s, SelectInput) for s in select):
331
+ return JoinInputs(renames=select)
332
+ elif all(isinstance(s, str) for s in select):
333
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
334
+ elif all(isinstance(s, dict) for s in select):
335
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
336
+
337
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
338
+ if isinstance(select, dict):
339
+ if 'select' in select:
340
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
341
+ if 'renames' in select:
342
+ return JoinInputs(**select)
343
+
344
+ raise ValueError(f"Invalid select format: {type(select)}")
345
+
346
+ def __init__(self,
347
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
348
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
349
+ **data):
350
+ """Custom init for backward compatibility with positional arguments."""
351
+ if left_select is not None:
352
+ data['left_select'] = left_select
353
+ if right_select is not None:
354
+ data['right_select'] = right_select
355
+ super().__init__(**data)
356
+
357
+ def to_yaml_dict(self) -> CrossJoinInputYaml:
358
+ """Serialize for YAML output."""
359
+ return {
360
+ "left_select": self.left_select.to_yaml_dict(),
361
+ "right_select": self.right_select.to_yaml_dict(),
362
+ }
363
+
364
+
365
+ class JoinInput(BaseModel):
366
+ """Data model for standard SQL-style join operations."""
310
367
  join_mapping: List[JoinMap]
311
- left_select: JoinInputs = None
312
- right_select: JoinInputs = None
368
+ left_select: JoinInputs
369
+ right_select: JoinInputs
313
370
  how: JoinStrategy = 'inner'
314
371
 
315
- @staticmethod
316
- def parse_join_mapping(join_mapping: any) -> List[JoinMap]:
317
- """Parses various input formats for join keys into a standardized list of `JoinMap` objects."""
318
- if isinstance(join_mapping, (tuple, list)):
319
- assert len(join_mapping) > 0
320
- if all(isinstance(jm, dict) for jm in join_mapping):
321
- join_mapping = [JoinMap(**jm) for jm in join_mapping]
322
-
323
- if not isinstance(join_mapping[0], JoinMap):
324
- assert len(join_mapping) <= 2
325
- if len(join_mapping) == 2:
326
- assert isinstance(join_mapping[0], str) and isinstance(join_mapping[1], str)
327
- join_mapping = [JoinMap(*join_mapping)]
328
- elif isinstance(join_mapping[0], str):
329
- join_mapping = [JoinMap(join_mapping[0], join_mapping[0])]
330
- elif isinstance(join_mapping, str):
331
- join_mapping = [JoinMap(join_mapping, join_mapping)]
332
- else:
333
- raise Exception('No valid join mapping as input')
334
- return join_mapping
335
-
336
- def __init__(self, join_mapping: List[JoinMap] | Tuple[str, str] | str,
337
- left_select: List[SelectInput] | List[str],
338
- right_select: List[SelectInput] | List[str],
339
- how: JoinStrategy = 'inner'):
340
- """Initializes the JoinInput with keys, selections, and join strategy."""
341
- self.join_mapping = self.parse_join_mapping(join_mapping)
342
- self.left_select = self.parse_select(left_select)
343
- self.right_select = self.parse_select(right_select)
344
- self.set_join_keys()
345
- self.how = how
346
-
347
- def set_join_keys(self):
348
- """Marks the `SelectInput` objects corresponding to join keys."""
349
- [setattr(v, "join_key", v.old_name in self._left_join_keys) for v in self.left_select.renames]
350
- [setattr(v, "join_key", v.old_name in self._right_join_keys) for v in self.right_select.renames]
351
-
352
- def get_join_key_renames(self, filter_drop: bool = False) -> FullJoinKeyResponse:
353
- """Gets the temporary rename mappings for the join keys on both sides."""
354
- return FullJoinKeyResponse(self.left_select.get_join_key_renames(side="left", filter_drop=filter_drop),
355
- self.right_select.get_join_key_renames(side="right", filter_drop=filter_drop))
356
-
357
- def get_names_for_table_rename(self) -> List[JoinMap]:
358
- new_mappings: List[JoinMap] = []
359
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
360
- for join_map in self.join_mapping:
361
- new_mappings.append(JoinMap(left_rename_table.get(join_map.left_col, join_map.left_col),
362
- right_rename_table.get(join_map.right_col, join_map.right_col)
363
- )
364
- )
365
- return new_mappings
366
-
367
- @property
368
- def _left_join_keys(self) -> Set:
369
- """Returns a set of the left-side join key column names."""
370
- return set(jm.left_col for jm in self.join_mapping)
371
-
372
- @property
373
- def _right_join_keys(self) -> Set:
374
- """Returns a set of the right-side join key column names."""
375
- return set(jm.right_col for jm in self.join_mapping)
372
+ @model_validator(mode='before')
373
+ @classmethod
374
+ def parse_inputs(cls, data: Any) -> Any:
375
+ """Parse flexible input formats before validation."""
376
+ if isinstance(data, dict):
377
+ # Parse join_mapping
378
+ if 'join_mapping' in data:
379
+ data['join_mapping'] = cls._parse_join_mapping(data['join_mapping'])
376
380
 
377
- @property
378
- def left_join_keys(self) -> List[str]:
379
- """Returns an ordered list of the left-side join key column names to be used in the join."""
380
- return [jm.left_col for jm in self.used_join_mapping]
381
+ # Parse left_select
382
+ if 'left_select' in data:
383
+ data['left_select'] = cls._parse_select(data['left_select'])
381
384
 
382
- @property
383
- def right_join_keys(self) -> List[str]:
384
- """Returns an ordered list of the right-side join key column names to be used in the join."""
385
- return [jm.right_col for jm in self.used_join_mapping]
385
+ # Parse right_select
386
+ if 'right_select' in data:
387
+ data['right_select'] = cls._parse_select(data['right_select'])
386
388
 
387
- @property
388
- def overlapping_records(self):
389
- if self.how in ('left', 'right', 'inner'):
390
- return self.left_select.new_cols & self.right_select.new_cols
391
- else:
392
- return self.left_select.new_cols & self.right_select.new_cols
393
-
394
- def auto_rename(self):
395
- """Automatically renames columns on the right side to prevent naming conflicts."""
396
- self.set_join_keys()
397
- overlapping_records = self.overlapping_records
398
- while len(overlapping_records) > 0:
399
- for right_col in self.right_select.renames:
400
- if right_col.new_name in overlapping_records:
401
- right_col.new_name = right_col.new_name + '_right'
402
- overlapping_records = self.overlapping_records
403
-
404
- @property
405
- def used_join_mapping(self) -> List[JoinMap]:
406
- """Returns the final join mapping after applying all renames and transformations."""
407
- new_mappings: List[JoinMap] = []
408
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
409
- left_join_rename_mapping: Dict[str, str] = self.left_select.get_join_key_rename_mapping("left")
410
- right_join_rename_mapping: Dict[str, str] = self.right_select.get_join_key_rename_mapping("right")
411
- for join_map in self.join_mapping:
412
- # del self.right_select.rename_table, self.left_select.rename_table
413
- new_mappings.append(JoinMap(left_join_rename_mapping.get(left_rename_table.get(join_map.left_col, join_map.left_col)),
414
- right_join_rename_mapping.get(right_rename_table.get(join_map.right_col, join_map.right_col))
415
- )
416
- )
417
- return new_mappings
389
+ return data
418
390
 
391
+ @staticmethod
392
+ def _parse_join_mapping(join_mapping: Any) -> List[JoinMap]:
393
+ """Parse various join_mapping formats."""
394
+ # Already a list of JoinMaps
395
+ if isinstance(join_mapping, list):
396
+ result = []
397
+ for jm in join_mapping:
398
+ if isinstance(jm, JoinMap):
399
+ result.append(jm)
400
+ elif isinstance(jm, dict):
401
+ result.append(JoinMap(**jm))
402
+ elif isinstance(jm, (tuple, list)) and len(jm) == 2:
403
+ result.append(JoinMap(left_col=jm[0], right_col=jm[1]))
404
+ elif isinstance(jm, str):
405
+ result.append(JoinMap(left_col=jm, right_col=jm))
406
+ else:
407
+ raise ValueError(f"Invalid join mapping item: {jm}")
408
+ return result
409
+
410
+ # Single JoinMap
411
+ if isinstance(join_mapping, JoinMap):
412
+ return [join_mapping]
413
+
414
+ # String: same column on both sides
415
+ if isinstance(join_mapping, str):
416
+ return [JoinMap(left_col=join_mapping, right_col=join_mapping)]
417
+
418
+ # Tuple: (left, right)
419
+ if isinstance(join_mapping, tuple) and len(join_mapping) == 2:
420
+ return [JoinMap(left_col=join_mapping[0], right_col=join_mapping[1])]
421
+
422
+ raise ValueError(f"Invalid join_mapping format: {type(join_mapping)}")
419
423
 
420
- @dataclass
421
- class FuzzyMatchInput(JoinInput):
422
- """Extends `JoinInput` with settings specific to fuzzy matching, such as the matching algorithm and threshold."""
424
+ @staticmethod
425
+ def _parse_select(select: Any) -> JoinInputs:
426
+ """Parse various select input formats."""
427
+ # Already JoinInputs
428
+ if isinstance(select, JoinInputs):
429
+ return select
430
+
431
+ # List of SelectInput objects
432
+ if isinstance(select, list):
433
+ if all(isinstance(s, SelectInput) for s in select):
434
+ return JoinInputs(renames=select)
435
+ elif all(isinstance(s, str) for s in select):
436
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
437
+ elif all(isinstance(s, dict) for s in select):
438
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
439
+
440
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
441
+ if isinstance(select, dict):
442
+ if 'select' in select:
443
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
444
+ if 'renames' in select:
445
+ return JoinInputs(**select)
446
+
447
+ raise ValueError(f"Invalid select format: {type(select)}")
448
+
449
+ def __init__(self,
450
+ join_mapping: Union[List[JoinMap], JoinMap, Tuple[str, str], str, List[Tuple], List[str]] = None,
451
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
452
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
453
+ how: JoinStrategy = 'inner',
454
+ **data):
455
+ """Custom init for backward compatibility with positional arguments."""
456
+ if join_mapping is not None:
457
+ data['join_mapping'] = join_mapping
458
+ if left_select is not None:
459
+ data['left_select'] = left_select
460
+ if right_select is not None:
461
+ data['right_select'] = right_select
462
+ if how is not None:
463
+ data['how'] = how
464
+
465
+ super().__init__(**data)
466
+
467
+ def to_yaml_dict(self) -> JoinInputYaml:
468
+ """Serialize for YAML output."""
469
+ return {
470
+ "join_mapping": [{"left_col": jm.left_col, "right_col": jm.right_col} for jm in self.join_mapping],
471
+ "left_select": self.left_select.to_yaml_dict(),
472
+ "right_select": self.right_select.to_yaml_dict(),
473
+ "how": self.how,
474
+ }
475
+
476
+
477
+ class FuzzyMatchInput(BaseModel):
478
+ """Data model for fuzzy matching join operations."""
423
479
  join_mapping: List[FuzzyMapping]
480
+ left_select: JoinInputs
481
+ right_select: JoinInputs
482
+ how: JoinStrategy = 'inner'
424
483
  aggregate_output: bool = False
425
484
 
426
- @staticmethod
427
- def parse_fuzz_mapping(fuzz_mapping: List[FuzzyMapping] | Tuple[str, str] | str) -> List[FuzzyMapping]:
428
- if isinstance(fuzz_mapping, (tuple, list)):
429
- assert len(fuzz_mapping) > 0
430
- if all(isinstance(fm, dict) for fm in fuzz_mapping):
431
- fuzz_mapping = [FuzzyMapping(**fm) for fm in fuzz_mapping]
485
+ def __init__(self,
486
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
487
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
488
+ **data):
489
+ """Custom init for backward compatibility with positional arguments."""
490
+ if left_select is not None:
491
+ data['left_select'] = left_select
492
+ if right_select is not None:
493
+ data['right_select'] = right_select
494
+
495
+ super().__init__(**data)
496
+
497
+ def to_yaml_dict(self) -> FuzzyMatchInputYaml:
498
+ """Serialize for YAML output."""
499
+ return {
500
+ "join_mapping": [asdict(jm) for jm in self.join_mapping],
501
+ "left_select": self.left_select.to_yaml_dict(),
502
+ "right_select": self.right_select.to_yaml_dict(),
503
+ "how": self.how,
504
+ "aggregate_output": self.aggregate_output,
505
+ }
432
506
 
433
- if not isinstance(fuzz_mapping[0], FuzzyMapping):
434
- assert len(fuzz_mapping) <= 2
435
- if len(fuzz_mapping) == 2:
436
- assert isinstance(fuzz_mapping[0], str) and isinstance(fuzz_mapping[1], str)
437
- fuzz_mapping = [FuzzyMapping(*fuzz_mapping)]
438
- elif isinstance(fuzz_mapping[0], str):
439
- fuzz_mapping = [FuzzyMapping(fuzz_mapping[0], fuzz_mapping[0])]
440
- elif isinstance(fuzz_mapping, str):
441
- fuzz_mapping = [FuzzyMapping(fuzz_mapping, fuzz_mapping)]
442
- elif isinstance(fuzz_mapping, FuzzyMapping):
443
- fuzz_mapping = [fuzz_mapping]
444
- else:
445
- raise Exception('No valid join mapping as input')
446
- return fuzz_mapping
447
-
448
- def __init__(self, join_mapping: List[FuzzyMapping] | Tuple[str, str] | str, left_select: List[SelectInput] | List[str],
449
- right_select: List[SelectInput] | List[str], aggregate_output: bool = False, how: JoinStrategy = 'inner'):
450
- self.join_mapping = self.parse_fuzz_mapping(join_mapping)
451
- self.left_select = self.parse_select(left_select)
452
- self.right_select = self.parse_select(right_select)
453
- self.how = how
454
- for jm in self.join_mapping:
455
-
456
- if jm.right_col not in {v.old_name for v in self.right_select.renames}:
457
- self.right_select.append(SelectInput(jm.right_col, keep=False, join_key=True))
458
- if jm.left_col not in {v.old_name for v in self.left_select.renames}:
459
- self.left_select.append(SelectInput(jm.left_col, keep=False, join_key=True))
460
- [setattr(v, "join_key", v.old_name in self._left_join_keys) for v in self.left_select.renames]
461
- [setattr(v, "join_key", v.old_name in self._right_join_keys) for v in self.right_select.renames]
462
- self.aggregate_output = aggregate_output
507
+ @staticmethod
508
+ def _parse_select(select: Any) -> JoinInputs:
509
+ """Parse various select input formats."""
510
+ # Already JoinInputs
511
+ if isinstance(select, JoinInputs):
512
+ return select
513
+
514
+ # List of SelectInput objects
515
+ if isinstance(select, list):
516
+ if all(isinstance(s, SelectInput) for s in select):
517
+ return JoinInputs(renames=select)
518
+ elif all(isinstance(s, str) for s in select):
519
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
520
+ elif all(isinstance(s, dict) for s in select):
521
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
522
+
523
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
524
+ if isinstance(select, dict):
525
+ if 'select' in select:
526
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
527
+ if 'renames' in select:
528
+ return JoinInputs(**select)
529
+
530
+ raise ValueError(f"Invalid select format: {type(select)}")
531
+
532
+ @model_validator(mode='before')
533
+ @classmethod
534
+ def parse_inputs(cls, data: Any) -> Any:
535
+ """Parse flexible input formats before validation."""
536
+ if isinstance(data, dict):
537
+ # Parse left_select
538
+ if 'left_select' in data:
539
+ data['left_select'] = cls._parse_select(data['left_select'])
463
540
 
464
- @property
465
- def overlapping_records(self):
466
- return self.left_select.new_cols & self.right_select.new_cols
541
+ # Parse right_select
542
+ if 'right_select' in data:
543
+ data['right_select'] = cls._parse_select(data['right_select'])
467
544
 
468
- @property
469
- def fuzzy_maps(self) -> List[FuzzyMapping]:
470
- """Returns the final fuzzy mappings after applying all column renames."""
471
- new_mappings = []
472
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
473
- for org_fuzzy_map in self.join_mapping:
474
- right_col = right_rename_table.get(org_fuzzy_map.right_col)
475
- left_col = left_rename_table.get(org_fuzzy_map.left_col)
476
- if right_col != org_fuzzy_map.right_col or left_col != org_fuzzy_map.left_col:
477
- new_mapping = deepcopy(org_fuzzy_map)
478
- new_mapping.left_col = left_col
479
- new_mapping.right_col = right_col
480
- new_mappings.append(new_mapping)
481
- else:
482
- new_mappings.append(org_fuzzy_map)
483
- return new_mappings
545
+ return data
484
546
 
485
547
 
486
- @dataclass
487
- class AggColl:
548
+ class AggColl(BaseModel):
488
549
  """
489
550
  A data class that represents a single aggregation operation for a group by operation.
490
551
 
@@ -493,7 +554,7 @@ class AggColl:
493
554
  old_name : str
494
555
  The name of the column in the original DataFrame to be aggregated.
495
556
 
496
- agg : Any
557
+ agg : str
497
558
  The aggregation function to use. This can be a string representing a built-in function or a custom function.
498
559
 
499
560
  new_name : Optional[str]
@@ -515,18 +576,36 @@ class AggColl:
515
576
  """
516
577
  old_name: str
517
578
  agg: str
518
- new_name: Optional[str]
579
+ new_name: Optional[str] = None
519
580
  output_type: Optional[str] = None
520
581
 
521
- def __init__(self, old_name: str, agg: str, new_name: str = None, output_type: str = None):
522
- """Initializes an aggregation column with its source, function, and new name."""
523
- self.old_name = str(old_name)
524
- if agg != 'groupby':
525
- self.new_name = new_name if new_name is not None else self.old_name + "_" + agg
526
- else:
527
- self.new_name = new_name if new_name is not None else self.old_name
528
- self.output_type = output_type if output_type is not None else get_func_type_mapping(agg)
529
- self.agg = agg
582
+ def __init__(self, old_name: str, agg: str, new_name: Optional[str] = None, output_type: Optional[str] = None):
583
+ data = {'old_name': old_name, 'agg': agg}
584
+ if new_name is not None:
585
+ data['new_name'] = new_name
586
+ if output_type is not None:
587
+ data['output_type'] = output_type
588
+
589
+ super().__init__(**data)
590
+
591
+ @model_validator(mode='after')
592
+ def set_defaults(self):
593
+ """Set default new_name and output_type based on agg function."""
594
+ # Set new_name
595
+ if self.new_name is None:
596
+ if self.agg != 'groupby':
597
+ self.new_name = self.old_name + "_" + self.agg
598
+ else:
599
+ self.new_name = self.old_name
600
+
601
+ # Set output_type
602
+ if self.output_type is None:
603
+ self.output_type = get_func_type_mapping(self.agg)
604
+
605
+ # Ensure old_name is a string
606
+ self.old_name = str(self.old_name)
607
+
608
+ return self
530
609
 
531
610
  @property
532
611
  def agg_func(self):
@@ -539,16 +618,12 @@ class AggColl:
539
618
  return getattr(pl, self.agg) if isinstance(self.agg, str) else self.agg
540
619
 
541
620
 
542
- @dataclass
543
- class GroupByInput:
621
+ class GroupByInput(BaseModel):
544
622
  """
545
623
  A data class that represents the input for a group by operation.
546
624
 
547
625
  Attributes
548
626
  ----------
549
- group_columns : List[str]
550
- A list of column names to group the DataFrame by. These column(s) will be set as the DataFrame index.
551
-
552
627
  agg_cols : List[AggColl]
553
628
  A list of `AggColl` objects that specify the aggregation operations to perform on the DataFrame columns
554
629
  after grouping. Each `AggColl` object should specify the column to be aggregated and the aggregation
@@ -557,14 +632,18 @@ class GroupByInput:
557
632
  Example
558
633
  --------
559
634
  group_by_input = GroupByInput(
560
- agg_cols=[AggColl(old_name='ix', agg='groupby'), AggColl(old_name='groups', agg='groupby'), AggColl(old_name='col1', agg='sum'), AggColl(old_name='col2', agg='mean')]
635
+ agg_cols=[AggColl(old_name='ix', agg='groupby'), AggColl(old_name='groups', agg='groupby'),
636
+ AggColl(old_name='col1', agg='sum'), AggColl(old_name='col2', agg='mean')]
561
637
  )
562
638
  """
563
639
  agg_cols: List[AggColl]
564
640
 
641
+ def __init__(self, agg_cols: List[AggColl]):
642
+ """Backwards compatibility implementation"""
643
+ super().__init__(agg_cols=agg_cols)
565
644
 
566
- @dataclass
567
- class PivotInput:
645
+
646
+ class PivotInput(BaseModel):
568
647
  """Defines the settings for a pivot (long-to-wide) operation."""
569
648
  index_columns: List[str]
570
649
  pivot_column: str
@@ -578,11 +657,13 @@ class PivotInput:
578
657
 
579
658
  def get_group_by_input(self) -> GroupByInput:
580
659
  """Constructs the `GroupByInput` needed for the pre-aggregation step of the pivot."""
581
- group_by_cols = [AggColl(c, 'groupby') for c in self.grouped_columns]
582
- agg_cols = [AggColl(self.value_col, agg=aggregation, new_name=aggregation) for aggregation in self.aggregations]
583
- return GroupByInput(group_by_cols+agg_cols)
660
+ group_by_cols = [AggColl(old_name=c, agg='groupby') for c in self.grouped_columns]
661
+ agg_cols = [AggColl(old_name=self.value_col, agg=aggregation, new_name=aggregation)
662
+ for aggregation in self.aggregations]
663
+ return GroupByInput(agg_cols=group_by_cols + agg_cols)
584
664
 
585
665
  def get_index_columns(self) -> List[pl.col]:
666
+ """Returns the index columns as Polars column expressions."""
586
667
  return [pl.col(c) for c in self.index_columns]
587
668
 
588
669
  def get_pivot_column(self) -> pl.Expr:
@@ -594,24 +675,21 @@ class PivotInput:
594
675
  return pl.struct([pl.col(c) for c in self.aggregations]).alias('vals')
595
676
 
596
677
 
597
- @dataclass
598
- class SortByInput:
678
+ class SortByInput(BaseModel):
599
679
  """Defines a single sort condition on a column, including the direction."""
600
680
  column: str
601
- how: str = 'asc'
681
+ how: Optional[str] = 'asc'
602
682
 
603
683
 
604
- @dataclass
605
- class RecordIdInput:
684
+ class RecordIdInput(BaseModel):
606
685
  """Defines settings for adding a record ID (row number) column to the data."""
607
686
  output_column_name: str = 'record_id'
608
687
  offset: int = 1
609
688
  group_by: Optional[bool] = False
610
- group_by_columns: Optional[List[str]] = field(default_factory=list)
689
+ group_by_columns: Optional[List[str]] = Field(default_factory=list)
611
690
 
612
691
 
613
- @dataclass
614
- class TextToRowsInput:
692
+ class TextToRowsInput(BaseModel):
615
693
  """Defines settings for splitting a text column into multiple rows based on a delimiter."""
616
694
  column_to_split: str
617
695
  output_column_name: Optional[str] = None
@@ -620,22 +698,14 @@ class TextToRowsInput:
620
698
  split_by_column: Optional[str] = None
621
699
 
622
700
 
623
- @dataclass
624
- class UnpivotInput:
701
+ class UnpivotInput(BaseModel):
625
702
  """Defines settings for an unpivot (wide-to-long) operation."""
626
- index_columns: Optional[List[str]] = field(default_factory=list)
627
- value_columns: Optional[List[str]] = field(default_factory=list)
628
- data_type_selector: Optional[Literal['float', 'all', 'date', 'numeric', 'string']] = None
629
- data_type_selector_mode: Optional[Literal['data_type', 'column']] = 'column'
703
+ model_config = ConfigDict(arbitrary_types_allowed=True)
630
704
 
631
- def __post_init__(self):
632
- """Ensures that list attributes are initialized correctly if they are None."""
633
- if self.index_columns is None:
634
- self.index_columns = []
635
- if self.value_columns is None:
636
- self.value_columns = []
637
- if self.data_type_selector_mode is None:
638
- self.data_type_selector_mode = 'column'
705
+ index_columns: List[str] = Field(default_factory=list)
706
+ value_columns: List[str] = Field(default_factory=list)
707
+ data_type_selector: Optional[Literal['float', 'all', 'date', 'numeric', 'string']] = None
708
+ data_type_selector_mode: Literal['data_type', 'column'] = 'column'
639
709
 
640
710
  @property
641
711
  def data_type_selector_expr(self) -> Optional[Callable]:
@@ -648,30 +718,630 @@ class UnpivotInput:
648
718
  print(f'Could not find the selector: {self.data_type_selector}')
649
719
  return selectors.all
650
720
  return selectors.all
721
+ return None
651
722
 
652
723
 
653
- @dataclass
654
- class UnionInput:
724
+ class UnionInput(BaseModel):
655
725
  """Defines settings for a union (concatenation) operation."""
656
726
  mode: Literal['selective', 'relaxed'] = 'relaxed'
657
727
 
658
728
 
659
- @dataclass
660
- class UniqueInput:
729
+ class UniqueInput(BaseModel):
661
730
  """Defines settings for a uniqueness operation, specifying columns and which row to keep."""
662
731
  columns: Optional[List[str]] = None
663
732
  strategy: Literal["first", "last", "any", "none"] = "any"
664
733
 
665
734
 
666
- @dataclass
667
- class GraphSolverInput:
735
+ class GraphSolverInput(BaseModel):
668
736
  """Defines settings for a graph-solving operation (e.g., finding connected components)."""
669
737
  col_from: str
670
738
  col_to: str
671
739
  output_column_name: Optional[str] = 'graph_group'
672
740
 
673
741
 
674
- @dataclass
675
- class PolarsCodeInput:
742
+ class PolarsCodeInput(BaseModel):
676
743
  """A simple container for a string of user-provided Polars code to be executed."""
677
744
  polars_code: str
745
+
746
+
747
+ class SelectInputsManager:
748
+ """Manager class that provides all query and mutation operations."""
749
+
750
+ def __init__(self, select_inputs: SelectInputs):
751
+ self.select_inputs = select_inputs
752
+
753
+ # === Query Methods (read-only) ===
754
+
755
+ def get_old_cols(self) -> Set[str]:
756
+ """Returns a set of original column names to be kept in the selection."""
757
+ return set(v.old_name for v in self.select_inputs.renames if v.keep)
758
+
759
+ def get_new_cols(self) -> Set[str]:
760
+ """Returns a set of new (renamed) column names to be kept in the selection."""
761
+ return set(v.new_name for v in self.select_inputs.renames if v.keep)
762
+
763
+ def get_rename_table(self) -> dict[str, str]:
764
+ """Generates a dictionary for use in Polars' `.rename()` method."""
765
+ return {v.old_name: v.new_name for v in self.select_inputs.renames
766
+ if v.is_available and (v.keep or v.join_key)}
767
+
768
+ def get_select_cols(self, include_join_key: bool = True) -> List[str]:
769
+ """Gets a list of original column names to select from the source DataFrame."""
770
+ return [v.old_name for v in self.select_inputs.renames
771
+ if v.keep or (v.join_key and include_join_key)]
772
+
773
+ def has_drop_cols(self) -> bool:
774
+ """Checks if any column is marked to be dropped from the selection."""
775
+ return any(not v.keep for v in self.select_inputs.renames)
776
+
777
+ def get_drop_columns(self) -> List[SelectInput]:
778
+ """Returns a list of SelectInput objects that are marked to be dropped."""
779
+ return [v for v in self.select_inputs.renames if not v.keep and v.is_available]
780
+
781
+ def get_non_jk_drop_columns(self) -> List[SelectInput]:
782
+ """Returns drop columns that are not join keys."""
783
+ return [v for v in self.select_inputs.renames
784
+ if not v.keep and v.is_available and not v.join_key]
785
+
786
+ def find_by_old_name(self, old_name: str) -> Optional[SelectInput]:
787
+ """Find SelectInput by original column name."""
788
+ return next((v for v in self.select_inputs.renames if v.old_name == old_name), None)
789
+
790
+ def find_by_new_name(self, new_name: str) -> Optional[SelectInput]:
791
+ """Find SelectInput by new column name."""
792
+ return next((v for v in self.select_inputs.renames if v.new_name == new_name), None)
793
+
794
+ # === Mutation Methods ===
795
+
796
+ def append(self, other: SelectInput) -> None:
797
+ """Appends a new SelectInput to the list of renames."""
798
+ self.select_inputs.renames.append(other)
799
+
800
+ def remove_select_input(self, old_key: str) -> None:
801
+ """Removes a SelectInput from the list based on its original name."""
802
+ self.select_inputs.renames = [
803
+ rename for rename in self.select_inputs.renames
804
+ if rename.old_name != old_key
805
+ ]
806
+
807
+ def unselect_field(self, old_key: str) -> None:
808
+ """Marks a field to be dropped from the final selection by setting `keep` to False."""
809
+ for rename in self.select_inputs.renames:
810
+ if old_key == rename.old_name:
811
+ rename.keep = False
812
+
813
+ # === Backward Compatibility Properties ===
814
+
815
+ @property
816
+ def old_cols(self) -> Set[str]:
817
+ """Backward compatibility: Returns set of old column names."""
818
+ return self.get_old_cols()
819
+
820
+ @property
821
+ def new_cols(self) -> Set[str]:
822
+ """Backward compatibility: Returns set of new column names."""
823
+ return self.get_new_cols()
824
+
825
+ @property
826
+ def rename_table(self) -> dict[str, str]:
827
+ """Backward compatibility: Returns rename table dictionary."""
828
+ return self.get_rename_table()
829
+
830
+ @property
831
+ def drop_columns(self) -> List[SelectInput]:
832
+ """Backward compatibility: Returns list of columns to drop."""
833
+ return self.get_drop_columns()
834
+
835
+ @property
836
+ def non_jk_drop_columns(self) -> List[SelectInput]:
837
+ """Backward compatibility: Returns non-join-key columns to drop."""
838
+ return self.get_non_jk_drop_columns()
839
+
840
+ @property
841
+ def renames(self) -> List[SelectInput]:
842
+ """Backward compatibility: Direct access to renames list."""
843
+ return self.select_inputs.renames
844
+
845
+ def get_select_input_on_old_name(self, old_name: str) -> Optional[SelectInput]:
846
+ """Backward compatibility alias: Find SelectInput by original column name."""
847
+ return self.find_by_old_name(old_name)
848
+
849
+ def get_select_input_on_new_name(self, new_name: str) -> Optional[SelectInput]:
850
+ """Backward compatibility alias: Find SelectInput by new column name."""
851
+ return self.find_by_new_name(new_name)
852
+
853
+ def __add__(self, other: SelectInput) -> "SelectInputsManager":
854
+ """Backward compatibility: Support += operator for appending."""
855
+ self.append(other)
856
+ return self
857
+
858
+
859
+ class JoinInputsManager(SelectInputsManager):
860
+ """Manager for join-specific operations, extends SelectInputsManager."""
861
+
862
+ def __init__(self, join_inputs: JoinInputs):
863
+ super().__init__(join_inputs)
864
+ self.join_inputs = join_inputs
865
+
866
+ # === Query Methods ===
867
+
868
+ def get_join_key_selects(self) -> List[SelectInput]:
869
+ """Returns only the `SelectInput` objects that are marked as join keys."""
870
+ return [v for v in self.join_inputs.renames if v.join_key]
871
+
872
+ def get_join_key_renames(self, side: SideLit, filter_drop: bool = False) -> JoinKeyRenameResponse:
873
+ """Gets the temporary rename mapping for all join keys on one side of a join."""
874
+ join_key_selects = self.get_join_key_selects()
875
+ join_key_list = [
876
+ JoinKeyRename(jk.new_name, construct_join_key_name(side, jk.new_name))
877
+ for jk in join_key_selects
878
+ if jk.keep or not filter_drop
879
+ ]
880
+ return JoinKeyRenameResponse(side, join_key_list)
881
+
882
+ def get_join_key_rename_mapping(self, side: SideLit) -> Dict[str, str]:
883
+ """Returns a dictionary mapping original join key names to their temporary names."""
884
+ join_key_response = self.get_join_key_renames(side)
885
+ return {jkr.original_name: jkr.temp_name for jkr in join_key_response.join_key_renames}
886
+
887
+ @property
888
+ def join_key_selects(self) -> List[SelectInput]:
889
+ """Backward compatibility: Returns join key SelectInputs."""
890
+ return self.get_join_key_selects()
891
+
892
+
893
class JoinSelectManagerMixin:
    """Shared helpers for the managers of join-like operations.

    The helpers below read ``input``, ``left_manager`` and ``right_manager``;
    concrete managers are responsible for assigning them.
    """

    left_manager: JoinInputsManager
    right_manager: JoinInputsManager
    input: Union[CrossJoinInput, JoinInput, FuzzyMatchInput]

    @staticmethod
    def parse_select(select: Union[List[SelectInput], List[str], List[Dict], Dict]) -> JoinInputs:
        """Normalise the supported select formats into a ``JoinInputs`` object.

        Accepted shapes, checked in this order: an empty/falsy value, a
        collection of ``SelectInput`` objects, a collection of dicts (each used
        as ``SelectInput`` keyword arguments), a dict with a ``'renames'`` entry,
        or a collection of column-name strings.

        Raises:
            ValueError: If ``select`` matches none of the supported shapes.
        """
        if not select:
            return JoinInputs(renames=[])

        def every_item_is(kind) -> bool:
            # Iterating a dict yields its keys, exactly like the inline checks would.
            return all(isinstance(item, kind) for item in select)

        if every_item_is(SelectInput):
            return JoinInputs(renames=select)
        if every_item_is(dict):
            return JoinInputs(renames=[SelectInput(**spec) for spec in select])
        if isinstance(select, dict):
            specs = select.get('renames')
            if not specs:
                return JoinInputs(renames=[])
            return JoinInputs(renames=[SelectInput(**spec) for spec in specs])
        if every_item_is(str):
            return JoinInputs(renames=[SelectInput(old_name=name, new_name=name) for name in select])

        raise ValueError(f"Unable to parse select input: {type(select)}")

    def get_overlapping_columns(self) -> Set[str]:
        """Return the output column names produced by both the left and the right side."""
        left_names = self.left_manager.get_new_cols()
        right_names = self.right_manager.get_new_cols()
        return left_names & right_names

    def auto_generate_new_col_name(self, old_col_name: str, side: str) -> str:
        """Prefix ``old_col_name`` with ``'<side>_'`` until it is not an overlapping name.

        The name is returned unchanged when it does not clash to begin with.
        """
        clashes = self.get_overlapping_columns()
        candidate = old_col_name
        while candidate in clashes:
            candidate = f'{side}_{candidate}'
        return candidate

    def add_new_select_column(self, select_input: SelectInput, side: str) -> None:
        """Append ``select_input`` to one side's renames after assigning its ``new_name``.

        ``side == 'right'`` targets the right selection; any other value targets
        the left selection.
        """
        if side == 'right':
            destination = self.input.right_select
        else:
            destination = self.input.left_select

        resolved_name = self.auto_generate_new_col_name(select_input.old_name, side=side)
        select_input.new_name = resolved_name
        destination.renames.append(select_input)
944
+
945
+
946
class CrossJoinInputManager(JoinSelectManagerMixin):
    """Manager wrapping a private copy of a ``CrossJoinInput``."""

    def __init__(self, cross_join_input: CrossJoinInput):
        # Work on a deep copy so later renames never touch the caller's object.
        managed = deepcopy(cross_join_input)
        self.input = managed
        self.left_manager = JoinInputsManager(managed.left_select)
        self.right_manager = JoinInputsManager(managed.right_select)

    @classmethod
    def create(cls, left_select: Union[List[SelectInput], List[str]],
               right_select: Union[List[SelectInput], List[str]]) -> "CrossJoinInputManager":
        """Build a manager from raw select definitions (parsed via ``parse_select``)."""
        parsed_left = cls.parse_select(left_select)
        parsed_right = cls.parse_select(right_select)
        return cls(CrossJoinInput(left_select=parsed_left, right_select=parsed_right))

    def get_overlapping_records(self) -> Set[str]:
        """Return the column names that both sides would output."""
        return self.get_overlapping_columns()

    def auto_rename(self, rename_mode: Literal["suffix", "prefix"] = "prefix") -> None:
        """Rename clashing right-side columns until no overlap remains.

        ``"prefix"`` prepends ``'right_'`` and ``"suffix"`` appends ``'_right'``.
        Only the right side's ``new_name`` values are modified.

        Raises:
            ValueError: If a clash has to be resolved and ``rename_mode`` is
                neither ``"prefix"`` nor ``"suffix"``.
        """
        clashes = self.get_overlapping_records()

        while clashes:
            for column in self.input.right_select.renames:
                if column.new_name not in clashes:
                    continue
                if rename_mode == "prefix":
                    column.new_name = 'right_' + column.new_name
                elif rename_mode == "suffix":
                    column.new_name += '_right'
                else:
                    raise ValueError(f'Unknown rename_mode: {rename_mode}')
            # A rename can itself collide with another name, so re-evaluate.
            clashes = self.get_overlapping_records()

    # --- attribute-style aliases kept for backward compatibility ---

    @property
    def left_select(self) -> JoinInputsManager:
        """Backward compatibility: alias for ``left_manager``."""
        return self.left_manager

    @property
    def right_select(self) -> JoinInputsManager:
        """Backward compatibility: alias for ``right_manager``."""
        return self.right_manager

    @property
    def overlapping_records(self) -> Set[str]:
        """Backward compatibility: alias for ``get_overlapping_records()``."""
        return self.get_overlapping_records()

    def to_cross_join_input(self) -> CrossJoinInput:
        """Export the manager's current state as a new ``CrossJoinInput``.

        Useful after the manager was modified (e.g. via ``auto_rename``).

        Returns:
            A new CrossJoinInput whose rename lists are shallow copies of the
            manager's current lists.
        """
        left_side = JoinInputs(renames=self.input.left_select.renames.copy())
        right_side = JoinInputs(renames=self.input.right_select.renames.copy())
        return CrossJoinInput(left_select=left_side, right_select=right_side)
1016
+
1017
+
1018
class JoinInputManager(JoinSelectManagerMixin):
    """Manager for standard SQL-style join operations.

    Holds a private deep copy of a ``JoinInput`` (``self.input``) plus one
    ``JoinInputsManager`` per side, and derives all join-key information from
    ``self.input.join_mapping``.
    """

    def __init__(self, join_input: JoinInput):
        """Copy ``join_input`` and flag its join-key columns.

        Args:
            join_input: Join definition to manage. It is deep-copied, so the
                manager never mutates the object passed in.
        """
        self.input = deepcopy(join_input)
        self.left_manager = JoinInputsManager(self.input.left_select)
        self.right_manager = JoinInputsManager(self.input.right_select)
        self.set_join_keys()

    @classmethod
    def create(cls, join_mapping: Union[List[JoinMap], Tuple[str, str], str],
               left_select: Union[List[SelectInput], List[str]],
               right_select: Union[List[SelectInput], List[str]],
               how: JoinStrategy = 'inner') -> "JoinInputManager":
        """Factory method to create a manager from various input formats.

        Args:
            join_mapping: Join key definition, passed unchanged to ``JoinInput``.
            left_select: Left-side column selection, passed unchanged to ``JoinInput``.
            right_select: Right-side column selection, passed unchanged to ``JoinInput``.
            how: Join strategy; defaults to ``'inner'``.

        Returns:
            A manager wrapping the newly built ``JoinInput``.
        """
        # The raw arguments go straight to the JoinInput constructor.
        # NOTE(review): presumably JoinInput normalises the accepted formats
        # itself - confirm against the JoinInput model.
        join_input = JoinInput(
            join_mapping=join_mapping,
            left_select=left_select,
            right_select=right_select,
            how=how
        )

        manager = cls(join_input)
        # __init__ already flags the join keys; the explicit (idempotent) call is
        # kept so the factory's behaviour is unchanged.
        manager.set_join_keys()
        return manager

    def set_join_keys(self) -> None:
        """Set ``join_key`` on every ``SelectInput`` of both sides.

        A select is flagged when its ``old_name`` appears as a key column of its
        side in ``self.input.join_mapping``; otherwise the flag is cleared.
        """
        left_join_keys = self._get_left_join_keys_set()
        right_join_keys = self._get_right_join_keys_set()

        for select_input in self.input.left_select.renames:
            select_input.join_key = select_input.old_name in left_join_keys

        for select_input in self.input.right_select.renames:
            select_input.join_key = select_input.old_name in right_join_keys

    def _get_left_join_keys_set(self) -> Set[str]:
        """Internal: Returns a set of the left-side join key column names."""
        return {jm.left_col for jm in self.input.join_mapping}

    def _get_right_join_keys_set(self) -> Set[str]:
        """Internal: Returns a set of the right-side join key column names."""
        return {jm.right_col for jm in self.input.join_mapping}

    def get_left_join_keys(self) -> Set[str]:
        """Returns a set of the left-side join key column names."""
        return self._get_left_join_keys_set()

    def get_right_join_keys(self) -> Set[str]:
        """Returns a set of the right-side join key column names."""
        return self._get_right_join_keys_set()

    def get_left_join_keys_list(self) -> List[str]:
        """Returns the left-side key names of ``used_join_mapping``, in mapping order."""
        return [jm.left_col for jm in self.used_join_mapping]

    def get_right_join_keys_list(self) -> List[str]:
        """Returns the right-side key names of ``used_join_mapping``, in mapping order."""
        return [jm.right_col for jm in self.used_join_mapping]

    def get_overlapping_records(self) -> Set[str]:
        """Finds column names that would conflict after the join."""
        return self.get_overlapping_columns()

    def auto_rename(self) -> None:
        """Automatically renames columns on the right side to prevent naming conflicts.

        Appends ``'_right'`` to the ``new_name`` of every clashing right-side
        column and repeats until no overlap remains. Left-side names are never
        modified.
        """
        self.set_join_keys()
        overlapping_records = self.get_overlapping_records()

        while len(overlapping_records) > 0:
            for right_col in self.input.right_select.renames:
                if right_col.new_name in overlapping_records:
                    right_col.new_name = right_col.new_name + '_right'
            # A freshly suffixed name may clash again, so recompute the overlap.
            overlapping_records = self.get_overlapping_records()

    def get_join_key_renames(self, filter_drop: bool = False) -> FullJoinKeyResponse:
        """Gets the temporary rename mappings for the join keys on both sides.

        Args:
            filter_drop: Forwarded unchanged to each side's
                ``get_join_key_renames``.

        Returns:
            A ``FullJoinKeyResponse`` built from the left and right responses.
        """
        left_renames = self.left_manager.get_join_key_renames(side="left", filter_drop=filter_drop)
        right_renames = self.right_manager.get_join_key_renames(side="right", filter_drop=filter_drop)
        return FullJoinKeyResponse(left_renames, right_renames)

    def get_names_for_table_rename(self) -> List[JoinMap]:
        """Gets join mapping with renamed columns applied.

        Each key column is looked up in its side's rename table; columns without
        an entry keep their original name.
        """
        new_mappings: List[JoinMap] = []
        left_rename_table = self.left_manager.get_rename_table()
        right_rename_table = self.right_manager.get_rename_table()

        for join_map in self.input.join_mapping:
            new_left = left_rename_table.get(join_map.left_col, join_map.left_col)
            new_right = right_rename_table.get(join_map.right_col, join_map.right_col)
            new_mappings.append(JoinMap(left_col=new_left, right_col=new_right))

        return new_mappings

    def get_used_join_mapping(self) -> List[JoinMap]:
        """Returns the final join mapping after applying all renames and transformations.

        Each key column is first resolved through its side's rename table and
        then through that side's join-key rename mapping. A column missing from
        the join-key rename mapping resolves to ``None``.
        """
        new_mappings: List[JoinMap] = []
        left_rename_table = self.left_manager.get_rename_table()
        right_rename_table = self.right_manager.get_rename_table()
        left_join_rename_mapping = self.left_manager.get_join_key_rename_mapping("left")
        right_join_rename_mapping = self.right_manager.get_join_key_rename_mapping("right")
        for join_map in self.input.join_mapping:
            left_col = left_rename_table.get(join_map.left_col, join_map.left_col)
            # Fall back to the RIGHT column's own name when it has no rename
            # entry; falling back to the left column would resolve the wrong key.
            right_col = right_rename_table.get(join_map.right_col, join_map.right_col)

            final_left = left_join_rename_mapping.get(left_col, None)
            final_right = right_join_rename_mapping.get(right_col, None)

            new_mappings.append(JoinMap(left_col=final_left, right_col=final_right))

        return new_mappings

    def to_join_input(self) -> JoinInput:
        """Creates a new JoinInput instance based on the current manager settings.

        This is useful when you've modified the manager (e.g., via auto_rename) and
        want to get a fresh JoinInput with all the current settings applied.

        Returns:
            A new JoinInput instance with current settings. The rename lists are
            shallow copies of the manager's lists.
        """
        return JoinInput(
            join_mapping=self.input.join_mapping,
            left_select=JoinInputs(renames=self.input.left_select.renames.copy()),
            right_select=JoinInputs(renames=self.input.right_select.renames.copy()),
            how=self.input.how
        )

    # === Backward Compatibility Properties ===

    @property
    def left_select(self) -> JoinInputsManager:
        """Backward compatibility: Access left_manager as left_select.

        This returns the MANAGER, not the data model.
        Usage: manager.left_select.join_key_selects
        """
        return self.left_manager

    @property
    def right_select(self) -> JoinInputsManager:
        """Backward compatibility: Access right_manager as right_select.

        This returns the MANAGER, not the data model.
        Usage: manager.right_select.join_key_selects
        """
        return self.right_manager

    @property
    def how(self) -> JoinStrategy:
        """Backward compatibility: Access join strategy."""
        return self.input.how

    @property
    def join_mapping(self) -> List[JoinMap]:
        """Backward compatibility: Access join mapping."""
        return self.input.join_mapping

    @property
    def overlapping_records(self) -> Set[str]:
        """Backward compatibility: Returns overlapping column names."""
        return self.get_overlapping_records()

    @property
    def used_join_mapping(self) -> List[JoinMap]:
        """Backward compatibility: Returns used join mapping.

        This property is critical - it's used by left_join_keys and right_join_keys.
        """
        return self.get_used_join_mapping()

    @property
    def left_join_keys(self) -> List[str]:
        """Backward compatibility: Returns left join keys list.

        IMPORTANT: Uses the used_join_mapping PROPERTY (not method).
        """
        return [jm.left_col for jm in self.used_join_mapping]

    @property
    def right_join_keys(self) -> List[str]:
        """Backward compatibility: Returns right join keys list.

        IMPORTANT: Uses the used_join_mapping PROPERTY (not method).
        """
        return [jm.right_col for jm in self.used_join_mapping]

    @property
    def _left_join_keys(self) -> Set[str]:
        """Backward compatibility: Private property for left join key set."""
        return self._get_left_join_keys_set()

    @property
    def _right_join_keys(self) -> Set[str]:
        """Backward compatibility: Private property for right join key set."""
        return self._get_right_join_keys_set()
1214
+
1215
+
1216
class FuzzyMatchInputManager(JoinInputManager):
    """Manager wrapping a private copy of a ``FuzzyMatchInput``."""

    def __init__(self, fuzzy_input: FuzzyMatchInput):
        source = deepcopy(fuzzy_input)
        self.fuzzy_input = source

        # The base manager works on plain JoinMap pairs derived from the fuzzy mappings.
        plain_mapping = [
            JoinMap(left_col=fuzzy_map.left_col, right_col=fuzzy_map.right_col)
            for fuzzy_map in source.join_mapping
        ]
        join_input = JoinInput(
            join_mapping=plain_mapping,
            left_select=source.left_select,
            right_select=source.right_select,
            how=source.how,
        )
        super().__init__(join_input)

    @classmethod
    def create(cls, join_mapping: Union[List[FuzzyMapping], Tuple[str, str], str],
               left_select: Union[List[SelectInput], List[str]],
               right_select: Union[List[SelectInput], List[str]],
               aggregate_output: bool = False,
               how: JoinStrategy = 'inner') -> "FuzzyMatchInputManager":
        """Build a manager from raw mapping and select definitions.

        Key columns named in the mapping but absent from a side's selection are
        appended to that side as non-kept join-key selects.
        """
        mappings = cls.parse_fuzz_mapping(join_mapping)
        parsed_left = cls.parse_select(left_select)
        parsed_right = cls.parse_select(right_select)

        fuzzy_input = FuzzyMatchInput(
            join_mapping=mappings,
            left_select=parsed_left,
            right_select=parsed_right,
            how=how,
            aggregate_output=aggregate_output
        )
        manager = cls(fuzzy_input)

        known_right = {select.old_name for select in fuzzy_input.right_select.renames}
        known_left = {select.old_name for select in fuzzy_input.left_select.renames}

        for mapping in mappings:
            if mapping.right_col not in known_right:
                hidden_right_key = SelectInput(old_name=mapping.right_col, keep=False, join_key=True)
                manager.right_manager.append(hidden_right_key)
            if mapping.left_col not in known_left:
                hidden_left_key = SelectInput(old_name=mapping.left_col, keep=False, join_key=True)
                manager.left_manager.append(hidden_left_key)

        manager.set_join_keys()
        return manager

    @staticmethod
    def parse_fuzz_mapping(fuzz_mapping: Union[List[FuzzyMapping], Tuple[str, str],
                                               str, FuzzyMapping, List[Dict]]) -> List[FuzzyMapping]:
        """Normalise the supported mapping formats into ``FuzzyMapping`` objects.

        Accepts a list/tuple of dicts, a list/tuple of ``FuzzyMapping`` objects
        (returned as-is), a list/tuple of one or two column names, a single
        column name, or a single ``FuzzyMapping``.

        Raises:
            ValueError: For an empty list/tuple or any unsupported input.
        """
        if isinstance(fuzz_mapping, (tuple, list)):
            size = len(fuzz_mapping)
            if size == 0:
                raise ValueError("Fuzzy mapping cannot be empty")

            if all(isinstance(item, dict) for item in fuzz_mapping):
                return [FuzzyMapping(**item) for item in fuzz_mapping]

            if all(isinstance(item, FuzzyMapping) for item in fuzz_mapping):
                return fuzz_mapping

            # One name means "same column on both sides"; two names are (left, right).
            if size <= 2 and all(isinstance(item, str) for item in fuzz_mapping):
                left_name, right_name = fuzz_mapping[0], fuzz_mapping[-1]
                return [FuzzyMapping(left_col=left_name, right_col=right_name)]

        elif isinstance(fuzz_mapping, str):
            return [FuzzyMapping(left_col=fuzz_mapping, right_col=fuzz_mapping)]

        elif isinstance(fuzz_mapping, FuzzyMapping):
            return [fuzz_mapping]

        raise ValueError(f'No valid fuzzy mapping as input: {type(fuzz_mapping)}')

    def get_fuzzy_maps(self) -> List[FuzzyMapping]:
        """Return the fuzzy mappings with each side's rename table applied.

        Mappings whose columns are unaffected are returned as the stored
        objects; affected ones are deep-copied before their columns are updated.
        """
        resolved = []
        left_lookup = self.left_manager.get_rename_table()
        right_lookup = self.right_manager.get_rename_table()

        for original in self.fuzzy_input.join_mapping:
            renamed_right = right_lookup.get(original.right_col, original.right_col)
            renamed_left = left_lookup.get(original.left_col, original.left_col)

            if renamed_right == original.right_col and renamed_left == original.left_col:
                resolved.append(original)
                continue

            updated = deepcopy(original)
            updated.left_col = renamed_left
            updated.right_col = renamed_right
            resolved.append(updated)

        return resolved

    # --- attribute-style aliases kept for backward compatibility ---

    @property
    def fuzzy_maps(self) -> List[FuzzyMapping]:
        """Backward compatibility: alias for ``get_fuzzy_maps()``."""
        return self.get_fuzzy_maps()

    @property
    def join_mapping(self) -> List[FuzzyMapping]:
        """Backward compatibility: the fuzzy join mapping, as ``get_fuzzy_maps()`` returns it."""
        return self.get_fuzzy_maps()

    @property
    def aggregate_output(self) -> bool:
        """Backward compatibility: the ``aggregate_output`` setting of the fuzzy input."""
        return self.fuzzy_input.aggregate_output

    def to_fuzzy_match_input(self) -> FuzzyMatchInput:
        """Export the manager's current state as a new ``FuzzyMatchInput``.

        Useful after the manager was modified (e.g. via ``auto_rename``).

        Returns:
            A new FuzzyMatchInput whose rename lists are shallow copies of the
            manager's current lists.
        """
        stored = self.fuzzy_input
        left_side = JoinInputs(renames=self.input.left_select.renames.copy())
        right_side = JoinInputs(renames=self.input.right_select.renames.copy())
        return FuzzyMatchInput(
            join_mapping=stored.join_mapping,
            left_select=left_side,
            right_select=right_side,
            how=stored.how,
            aggregate_output=stored.aggregate_output
        )