Flowfile 0.3.9__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201) hide show
  1. flowfile/__init__.py +8 -1
  2. flowfile/api.py +1 -3
  3. flowfile/web/static/assets/{CloudConnectionManager-c97c25f8.js → CloudConnectionManager-0dfba9f2.js} +2 -2
  4. flowfile/web/static/assets/{CloudStorageReader-f1ff509e.js → CloudStorageReader-d5b1b6c9.js} +11 -78
  5. flowfile/web/static/assets/{CloudStorageWriter-034f8b78.js → CloudStorageWriter-00d87aad.js} +12 -79
  6. flowfile/web/static/assets/{CloudStorageWriter-49c9a4b2.css → CloudStorageWriter-b0ee067f.css} +24 -24
  7. flowfile/web/static/assets/ColumnSelector-4685e75d.js +83 -0
  8. flowfile/web/static/assets/ColumnSelector-47996a16.css +10 -0
  9. flowfile/web/static/assets/ContextMenu-23e909da.js +41 -0
  10. flowfile/web/static/assets/{SettingsSection-9c836ecc.css → ContextMenu-4c74eef1.css} +0 -21
  11. flowfile/web/static/assets/ContextMenu-63cfa99b.css +26 -0
  12. flowfile/web/static/assets/ContextMenu-70ae0c79.js +41 -0
  13. flowfile/web/static/assets/ContextMenu-c13f91d0.css +26 -0
  14. flowfile/web/static/assets/ContextMenu-f149cf7c.js +41 -0
  15. flowfile/web/static/assets/{CrossJoin-41efa4cb.css → CrossJoin-1119d18e.css} +18 -18
  16. flowfile/web/static/assets/{CrossJoin-9e156ebe.js → CrossJoin-702a3edd.js} +14 -84
  17. flowfile/web/static/assets/CustomNode-74a37f74.css +32 -0
  18. flowfile/web/static/assets/CustomNode-b1519993.js +211 -0
  19. flowfile/web/static/assets/{DatabaseConnectionSettings-d5c625b3.js → DatabaseConnectionSettings-6f3e4ea5.js} +3 -3
  20. flowfile/web/static/assets/{DatabaseManager-265adc5e.js → DatabaseManager-cf5ef661.js} +2 -2
  21. flowfile/web/static/assets/{DatabaseReader-f50c6558.css → DatabaseReader-ae61773c.css} +0 -27
  22. flowfile/web/static/assets/{DatabaseReader-0b10551e.js → DatabaseReader-d38c7295.js} +14 -114
  23. flowfile/web/static/assets/{DatabaseWriter-c17c6916.js → DatabaseWriter-b04ef46a.js} +13 -74
  24. flowfile/web/static/assets/{ExploreData-5bdae813.css → ExploreData-2d0cf4db.css} +8 -14
  25. flowfile/web/static/assets/ExploreData-5fa10ed8.js +192 -0
  26. flowfile/web/static/assets/{ExternalSource-3a66556c.js → ExternalSource-d39af878.js} +8 -79
  27. flowfile/web/static/assets/{Filter-91ad87e7.js → Filter-9b6d08db.js} +12 -85
  28. flowfile/web/static/assets/{Filter-a9d08ba1.css → Filter-f62091b3.css} +3 -3
  29. flowfile/web/static/assets/{Formula-3c395ab1.js → Formula-6b04fb1d.js} +20 -87
  30. flowfile/web/static/assets/{Formula-29f19d21.css → Formula-bb96803d.css} +4 -4
  31. flowfile/web/static/assets/{FuzzyMatch-6857de82.css → FuzzyMatch-1010f966.css} +42 -42
  32. flowfile/web/static/assets/{FuzzyMatch-2df0d230.js → FuzzyMatch-999521f4.js} +16 -87
  33. flowfile/web/static/assets/{GraphSolver-d285877f.js → GraphSolver-17dd2198.js} +13 -159
  34. flowfile/web/static/assets/GraphSolver-f0cb7bfb.css +22 -0
  35. flowfile/web/static/assets/{GroupBy-0bd1cc6b.js → GroupBy-6b039e18.js} +12 -75
  36. flowfile/web/static/assets/{Unique-b5615727.css → GroupBy-b9505323.css} +8 -8
  37. flowfile/web/static/assets/{Join-5a78a203.js → Join-24d0f113.js} +15 -85
  38. flowfile/web/static/assets/{Join-f45eff22.css → Join-fd79b451.css} +20 -20
  39. flowfile/web/static/assets/{ManualInput-a71b52c6.css → ManualInput-3246a08d.css} +20 -20
  40. flowfile/web/static/assets/{ManualInput-93aef9d6.js → ManualInput-34639209.js} +11 -82
  41. flowfile/web/static/assets/MultiSelect-0e8724a3.js +5 -0
  42. flowfile/web/static/assets/MultiSelect.vue_vue_type_script_setup_true_lang-b0e538c2.js +63 -0
  43. flowfile/web/static/assets/NumericInput-3d63a470.js +5 -0
  44. flowfile/web/static/assets/NumericInput.vue_vue_type_script_setup_true_lang-e0edeccc.js +35 -0
  45. flowfile/web/static/assets/Output-283fe388.css +37 -0
  46. flowfile/web/static/assets/{Output-411ecaee.js → Output-edea9802.js} +62 -273
  47. flowfile/web/static/assets/{Pivot-89db4b04.js → Pivot-61d19301.js} +14 -138
  48. flowfile/web/static/assets/Pivot-cf333e3d.css +22 -0
  49. flowfile/web/static/assets/PivotValidation-891ddfb0.css +13 -0
  50. flowfile/web/static/assets/PivotValidation-c46cd420.css +13 -0
  51. flowfile/web/static/assets/PivotValidation-de9f43fe.js +61 -0
  52. flowfile/web/static/assets/PivotValidation-f97fec5b.js +61 -0
  53. flowfile/web/static/assets/{PolarsCode-a9f974f8.js → PolarsCode-bc3c9984.js} +13 -80
  54. flowfile/web/static/assets/Read-64a3f259.js +218 -0
  55. flowfile/web/static/assets/Read-e808b239.css +62 -0
  56. flowfile/web/static/assets/RecordCount-3d5039be.js +53 -0
  57. flowfile/web/static/assets/{RecordId-55ae7d36.js → RecordId-597510e0.js} +8 -80
  58. flowfile/web/static/assets/SQLQueryComponent-36cef432.css +27 -0
  59. flowfile/web/static/assets/SQLQueryComponent-df51adbe.js +38 -0
  60. flowfile/web/static/assets/{Sample-b4a18476.js → Sample-4be0a507.js} +8 -77
  61. flowfile/web/static/assets/{SecretManager-b066d13a.js → SecretManager-4839be57.js} +2 -2
  62. flowfile/web/static/assets/{Select-727688dc.js → Select-9b72f201.js} +11 -85
  63. flowfile/web/static/assets/SettingsSection-2e4d03c4.css +21 -0
  64. flowfile/web/static/assets/SettingsSection-5c696bee.css +20 -0
  65. flowfile/web/static/assets/SettingsSection-71e6b7e3.css +21 -0
  66. flowfile/web/static/assets/SettingsSection-7ded385d.js +45 -0
  67. flowfile/web/static/assets/{SettingsSection-695ac487.js → SettingsSection-e1e9c953.js} +2 -40
  68. flowfile/web/static/assets/SettingsSection-f0f75a42.js +53 -0
  69. flowfile/web/static/assets/SingleSelect-6c777aac.js +5 -0
  70. flowfile/web/static/assets/SingleSelect.vue_vue_type_script_setup_true_lang-33e3ff9b.js +62 -0
  71. flowfile/web/static/assets/SliderInput-7cb93e62.js +40 -0
  72. flowfile/web/static/assets/SliderInput-b8fb6a8c.css +4 -0
  73. flowfile/web/static/assets/{GroupBy-ab1ea74b.css → Sort-3643d625.css} +8 -8
  74. flowfile/web/static/assets/{Sort-be3339a8.js → Sort-6cbde21a.js} +12 -97
  75. flowfile/web/static/assets/TextInput-d9a40c11.js +5 -0
  76. flowfile/web/static/assets/TextInput.vue_vue_type_script_setup_true_lang-5896c375.js +32 -0
  77. flowfile/web/static/assets/{TextToRows-c92d1ec2.css → TextToRows-5d2c1190.css} +9 -9
  78. flowfile/web/static/assets/{TextToRows-7b8998da.js → TextToRows-c4fcbf4d.js} +14 -83
  79. flowfile/web/static/assets/ToggleSwitch-4ef91d19.js +5 -0
  80. flowfile/web/static/assets/ToggleSwitch.vue_vue_type_script_setup_true_lang-38478c20.js +31 -0
  81. flowfile/web/static/assets/{UnavailableFields-8b0cb48e.js → UnavailableFields-a03f512c.js} +2 -2
  82. flowfile/web/static/assets/{Union-8d9ac7f9.css → Union-af6c3d9b.css} +6 -6
  83. flowfile/web/static/assets/Union-bfe9b996.js +77 -0
  84. flowfile/web/static/assets/{Unique-af5a80b4.js → Unique-5d023a27.js} +23 -104
  85. flowfile/web/static/assets/{Sort-7ccfa0fe.css → Unique-f9fb0809.css} +8 -8
  86. flowfile/web/static/assets/Unpivot-1e422df3.css +30 -0
  87. flowfile/web/static/assets/{Unpivot-5195d411.js → Unpivot-91cc5354.js} +12 -166
  88. flowfile/web/static/assets/UnpivotValidation-0d240eeb.css +13 -0
  89. flowfile/web/static/assets/UnpivotValidation-7ee2de44.js +51 -0
  90. flowfile/web/static/assets/{ExploreData-18a4fe52.js → VueGraphicWalker-e51b9924.js} +4 -264
  91. flowfile/web/static/assets/VueGraphicWalker-ed5ab88b.css +6 -0
  92. flowfile/web/static/assets/{api-cb00cce6.js → api-c1bad5ca.js} +1 -1
  93. flowfile/web/static/assets/{api-023d1733.js → api-cf1221f0.js} +1 -1
  94. flowfile/web/static/assets/{designer-2197d782.css → designer-8da3ba3a.css} +859 -201
  95. flowfile/web/static/assets/{designer-6c322d8e.js → designer-9633482a.js} +2297 -733
  96. flowfile/web/static/assets/{documentation-4d1fafe1.js → documentation-ca400224.js} +1 -1
  97. flowfile/web/static/assets/{dropDown-0b46dd77.js → dropDown-614b998d.js} +1 -1
  98. flowfile/web/static/assets/{fullEditor-ec4e4f95.js → fullEditor-f7971590.js} +2 -2
  99. flowfile/web/static/assets/{genericNodeSettings-def5879b.js → genericNodeSettings-4fe5f36b.js} +3 -3
  100. flowfile/web/static/assets/{index-681a3ed0.css → index-50508d4d.css} +8 -0
  101. flowfile/web/static/assets/{index-683fc198.js → index-5429bbf8.js} +208 -31
  102. flowfile/web/static/assets/nodeInput-5d0d6b79.js +41 -0
  103. flowfile/web/static/assets/outputCsv-076b85ab.js +86 -0
  104. flowfile/web/static/assets/{Output-48f81019.css → outputCsv-9cc59e0b.css} +0 -143
  105. flowfile/web/static/assets/outputExcel-0fd17dbe.js +56 -0
  106. flowfile/web/static/assets/outputExcel-b41305c0.css +102 -0
  107. flowfile/web/static/assets/outputParquet-b61e0847.js +31 -0
  108. flowfile/web/static/assets/outputParquet-cf8cf3f2.css +4 -0
  109. flowfile/web/static/assets/readCsv-a8bb8b61.js +179 -0
  110. flowfile/web/static/assets/readCsv-c767cb37.css +52 -0
  111. flowfile/web/static/assets/readExcel-67b4aee0.js +201 -0
  112. flowfile/web/static/assets/readExcel-806d2826.css +64 -0
  113. flowfile/web/static/assets/readParquet-48c81530.css +19 -0
  114. flowfile/web/static/assets/readParquet-92ce1dbc.js +23 -0
  115. flowfile/web/static/assets/{secretApi-baceb6f9.js → secretApi-68435402.js} +1 -1
  116. flowfile/web/static/assets/{selectDynamic-de91449a.js → selectDynamic-92e25ee3.js} +7 -7
  117. flowfile/web/static/assets/{selectDynamic-b062bc9b.css → selectDynamic-aa913ff4.css} +16 -16
  118. flowfile/web/static/assets/user-defined-icon-0ae16c90.png +0 -0
  119. flowfile/web/static/assets/{vue-codemirror.esm-dc5e3348.js → vue-codemirror.esm-41b0e0d7.js} +65 -36
  120. flowfile/web/static/assets/{vue-content-loader.es-ba94b82f.js → vue-content-loader.es-2c8e608f.js} +1 -1
  121. flowfile/web/static/index.html +2 -2
  122. {flowfile-0.3.9.dist-info → flowfile-0.5.1.dist-info}/METADATA +5 -3
  123. {flowfile-0.3.9.dist-info → flowfile-0.5.1.dist-info}/RECORD +191 -121
  124. {flowfile-0.3.9.dist-info → flowfile-0.5.1.dist-info}/WHEEL +1 -1
  125. {flowfile-0.3.9.dist-info → flowfile-0.5.1.dist-info}/entry_points.txt +1 -0
  126. flowfile_core/__init__.py +3 -0
  127. flowfile_core/configs/flow_logger.py +5 -13
  128. flowfile_core/configs/node_store/__init__.py +30 -0
  129. flowfile_core/configs/node_store/nodes.py +383 -99
  130. flowfile_core/configs/node_store/user_defined_node_registry.py +193 -0
  131. flowfile_core/configs/settings.py +2 -1
  132. flowfile_core/database/connection.py +5 -21
  133. flowfile_core/fileExplorer/funcs.py +239 -121
  134. flowfile_core/flowfile/analytics/analytics_processor.py +1 -0
  135. flowfile_core/flowfile/code_generator/code_generator.py +62 -64
  136. flowfile_core/flowfile/flow_data_engine/create/funcs.py +73 -56
  137. flowfile_core/flowfile/flow_data_engine/flow_data_engine.py +77 -86
  138. flowfile_core/flowfile/flow_data_engine/flow_file_column/interface.py +4 -0
  139. flowfile_core/flowfile/flow_data_engine/flow_file_column/main.py +19 -34
  140. flowfile_core/flowfile/flow_data_engine/flow_file_column/type_registry.py +36 -0
  141. flowfile_core/flowfile/flow_data_engine/fuzzy_matching/prepare_for_fuzzy_match.py +23 -23
  142. flowfile_core/flowfile/flow_data_engine/join/utils.py +1 -1
  143. flowfile_core/flowfile/flow_data_engine/join/verify_integrity.py +9 -4
  144. flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py +212 -86
  145. flowfile_core/flowfile/flow_data_engine/utils.py +2 -0
  146. flowfile_core/flowfile/flow_graph.py +240 -54
  147. flowfile_core/flowfile/flow_node/flow_node.py +48 -13
  148. flowfile_core/flowfile/flow_node/models.py +2 -1
  149. flowfile_core/flowfile/handler.py +24 -5
  150. flowfile_core/flowfile/manage/compatibility_enhancements.py +404 -41
  151. flowfile_core/flowfile/manage/io_flowfile.py +394 -0
  152. flowfile_core/flowfile/node_designer/__init__.py +47 -0
  153. flowfile_core/flowfile/node_designer/_type_registry.py +197 -0
  154. flowfile_core/flowfile/node_designer/custom_node.py +371 -0
  155. flowfile_core/flowfile/node_designer/ui_components.py +277 -0
  156. flowfile_core/flowfile/schema_callbacks.py +17 -10
  157. flowfile_core/flowfile/setting_generator/settings.py +15 -10
  158. flowfile_core/main.py +5 -1
  159. flowfile_core/routes/routes.py +73 -30
  160. flowfile_core/routes/user_defined_components.py +55 -0
  161. flowfile_core/schemas/cloud_storage_schemas.py +0 -2
  162. flowfile_core/schemas/input_schema.py +228 -65
  163. flowfile_core/schemas/output_model.py +5 -2
  164. flowfile_core/schemas/schemas.py +153 -35
  165. flowfile_core/schemas/transform_schema.py +1083 -412
  166. flowfile_core/schemas/yaml_types.py +103 -0
  167. flowfile_core/types.py +156 -0
  168. flowfile_core/utils/validate_setup.py +3 -1
  169. flowfile_frame/__init__.py +3 -1
  170. flowfile_frame/flow_frame.py +31 -24
  171. flowfile_frame/flow_frame_methods.py +12 -9
  172. flowfile_worker/__init__.py +9 -35
  173. flowfile_worker/create/__init__.py +3 -21
  174. flowfile_worker/create/funcs.py +68 -56
  175. flowfile_worker/create/models.py +130 -62
  176. flowfile_worker/main.py +5 -2
  177. flowfile_worker/routes.py +52 -13
  178. shared/__init__.py +15 -0
  179. shared/storage_config.py +258 -0
  180. tools/migrate/README.md +56 -0
  181. tools/migrate/__init__.py +12 -0
  182. tools/migrate/__main__.py +131 -0
  183. tools/migrate/legacy_schemas.py +621 -0
  184. tools/migrate/migrate.py +598 -0
  185. tools/migrate/tests/__init__.py +0 -0
  186. tools/migrate/tests/conftest.py +23 -0
  187. tools/migrate/tests/test_migrate.py +627 -0
  188. tools/migrate/tests/test_migration_e2e.py +1010 -0
  189. tools/migrate/tests/test_node_migrations.py +813 -0
  190. flowfile/web/static/assets/GraphSolver-17fd26db.css +0 -68
  191. flowfile/web/static/assets/Pivot-f415e85f.css +0 -35
  192. flowfile/web/static/assets/Read-80dc1675.css +0 -197
  193. flowfile/web/static/assets/Read-c3b1929c.js +0 -701
  194. flowfile/web/static/assets/RecordCount-4e95f98e.js +0 -122
  195. flowfile/web/static/assets/Union-89fd73dc.js +0 -146
  196. flowfile/web/static/assets/Unpivot-246e9bbd.css +0 -77
  197. flowfile/web/static/assets/nodeTitle-a16db7c3.js +0 -227
  198. flowfile/web/static/assets/nodeTitle-f4b12bcb.css +0 -134
  199. flowfile_core/flowfile/manage/open_flowfile.py +0 -135
  200. {flowfile-0.3.9.dist-info → flowfile-0.5.1.dist-info/licenses}/LICENSE +0 -0
  201. /flowfile_core/flowfile/manage/manage_flowfile.py → /tools/__init__.py +0 -0
@@ -1,12 +1,21 @@
1
1
  from typing import List, Dict, Tuple, Set, Optional, Literal, Callable
2
- from dataclasses import dataclass, field
2
+ from dataclasses import asdict
3
3
  import polars as pl
4
4
  from polars import selectors
5
5
  from copy import deepcopy
6
+ from pydantic import BaseModel, ConfigDict, model_validator, Field
7
+ from typing import NamedTuple, Union, Any
8
+ from flowfile_core.schemas.yaml_types import (
9
+ SelectInputYaml, JoinInputsYaml, JoinInputYaml,
10
+ CrossJoinInputYaml, FuzzyMatchInputYaml
11
+ )
12
+ from pl_fuzzy_frame_match.models import FuzzyMapping
6
13
 
7
- from typing import NamedTuple
14
+ from flowfile_core.types import DataType, DataTypeStr
8
15
 
9
- from pl_fuzzy_frame_match.models import FuzzyMapping
16
+ FuzzyMap = FuzzyMapping
17
+
18
+ AUTO_DATA_TYPE = "Auto"
10
19
 
11
20
 
12
21
  def get_func_type_mapping(func: str):
@@ -54,436 +63,489 @@ class FullJoinKeyResponse(NamedTuple):
54
63
  right: JoinKeyRenameResponse
55
64
 
56
65
 
57
- @dataclass
58
- class SelectInput:
66
+ class SelectInput(BaseModel):
59
67
  """Defines how a single column should be selected, renamed, or type-cast.
60
68
 
61
69
  This is a core building block for any operation that involves column manipulation.
62
70
  It holds all the configuration for a single field in a selection operation.
63
71
  """
72
+ model_config = ConfigDict(frozen=False)
73
+
64
74
  old_name: str
65
75
  original_position: Optional[int] = None
66
76
  new_name: Optional[str] = None
67
77
  data_type: Optional[str] = None
68
- data_type_change: Optional[bool] = False
69
- join_key: Optional[bool] = False
70
- is_altered: Optional[bool] = False
78
+ data_type_change: bool = False
79
+ join_key: bool = False
80
+ is_altered: bool = False
71
81
  position: Optional[int] = None
72
- is_available: Optional[bool] = True
73
- keep: Optional[bool] = True
82
+ is_available: bool = True
83
+ keep: bool = True
84
+
85
+ def __init__(self, old_name: str = None, new_name: str = None, **data):
86
+ if old_name is not None:
87
+ data['old_name'] = old_name
88
+ if new_name is not None:
89
+ data['new_name'] = new_name
90
+ super().__init__(**data)
91
+
92
+ def to_yaml_dict(self) -> SelectInputYaml:
93
+ """Serialize for YAML output - only user-relevant fields."""
94
+ result: SelectInputYaml = {"old_name": self.old_name}
95
+ if self.new_name != self.old_name:
96
+ result["new_name"] = self.new_name
97
+ if not self.keep:
98
+ result["keep"] = self.keep
99
+ if self.data_type_change and self.data_type:
100
+ result["data_type"] = self.data_type
101
+ return result
102
+
103
+ @classmethod
104
+ def from_yaml_dict(cls, data: dict) -> "SelectInput":
105
+ """Load from slim YAML format."""
106
+ old_name = data["old_name"]
107
+ new_name = data.get("new_name", old_name)
108
+ return cls(
109
+ old_name=old_name,
110
+ new_name=new_name,
111
+ keep=data.get("keep", True),
112
+ data_type=data.get("data_type"),
113
+ data_type_change=data.get("data_type") is not None,
114
+ is_altered=old_name != new_name,
115
+ )
116
+
117
+ @model_validator(mode='after')
118
+ def set_default_new_name(self):
119
+ """If new_name is None, default it to old_name."""
120
+ if self.new_name is None:
121
+ self.new_name = self.old_name
122
+ if self.old_name != self.new_name:
123
+ self.is_altered = True
124
+ return self
74
125
 
75
126
  def __hash__(self):
127
+ """Allow SelectInput to be used in sets and as dict keys."""
76
128
  return hash(self.old_name)
77
129
 
78
- def __init__(self, old_name: str, new_name: str = None, keep: bool = True, data_type: str = None,
79
- data_type_change: bool = False, join_key: bool = False, is_altered: bool = False,
80
- is_available: bool = True, position: int = None):
81
- self.old_name = old_name
82
- if new_name is None:
83
- new_name = old_name
84
- self.new_name = new_name
85
- self.keep = keep
86
- self.data_type = data_type
87
- self.data_type_change = data_type_change
88
- self.join_key = join_key
89
- self.is_altered = is_altered
90
- self.is_available = is_available
91
- self.position = position
130
+ def __eq__(self, other):
131
+ """Required when implementing __hash__."""
132
+ if not isinstance(other, SelectInput):
133
+ return False
134
+ return self.old_name == other.old_name
92
135
 
93
136
  @property
94
137
  def polars_type(self) -> str:
95
138
  """Translates a user-friendly type name to a Polars data type string."""
96
- if self.data_type.lower() == 'string':
139
+ data_type_lower = self.data_type.lower()
140
+ if data_type_lower == 'string':
97
141
  return 'Utf8'
98
- elif self.data_type.lower() == 'integer':
142
+ elif data_type_lower == 'integer':
99
143
  return 'Int64'
100
- elif self.data_type.lower() == 'double':
144
+ elif data_type_lower == 'double':
101
145
  return 'Float64'
102
146
  return self.data_type
103
147
 
104
148
 
105
- @dataclass
106
- class FieldInput:
149
+ class FieldInput(BaseModel):
107
150
  """Represents a single field with its name and data type, typically for defining an output column."""
108
151
  name: str
109
- data_type: Optional[str] = None
110
-
111
- def __init__(self, name: str, data_type: str = None):
112
- self.name = name
113
- self.data_type = data_type
152
+ data_type: DataType | Literal["Auto"] | DataTypeStr | None = AUTO_DATA_TYPE
114
153
 
115
154
 
116
- @dataclass
117
- class FunctionInput:
155
+ class FunctionInput(BaseModel):
118
156
  """Defines a formula to be applied, including the output field information."""
119
157
  field: FieldInput
120
158
  function: str
121
159
 
160
+ def __init__(self, field: FieldInput = None, function: str = None, **data):
161
+ if field is not None:
162
+ data['field'] = field
163
+ if function is not None:
164
+ data['function'] = function
165
+ super().__init__(**data)
122
166
 
123
- @dataclass
124
- class BasicFilter:
125
- """Defines a simple, single-condition filter (e.g., 'column' 'equals' 'value')."""
126
- field: str = ''
127
- filter_type: str = ''
128
- filter_value: str = ''
129
-
130
-
131
- @dataclass
132
- class FilterInput:
133
- """Defines the settings for a filter operation, supporting basic or advanced (expression-based) modes."""
134
- advanced_filter: str = ''
135
- basic_filter: BasicFilter = None
136
- filter_type: str = 'basic'
137
-
138
-
139
- @dataclass
140
- class SelectInputs:
141
- """A container for a list of `SelectInput` objects, providing helper methods for managing selections."""
142
- renames: List[SelectInput]
143
-
144
- @property
145
- def old_cols(self) -> Set:
146
- """Returns a set of original column names to be kept in the selection."""
147
- return set(v.old_name for v in self.renames if v.keep)
148
-
149
- @property
150
- def new_cols(self) -> Set:
151
- """Returns a set of new (renamed) column names to be kept in the selection."""
152
- return set(v.new_name for v in self.renames if v.keep)
153
-
154
- @property
155
- def rename_table(self):
156
- """Generates a dictionary for use in Polars' `.rename()` method."""
157
- return {v.old_name: v.new_name for v in self.renames if v.is_available and (v.keep or v.join_key)}
158
-
159
- def get_select_cols(self, include_join_key: bool = True):
160
- """Gets a list of original column names to select from the source DataFrame."""
161
- return [v.old_name for v in self.renames if v.keep or (v.join_key and include_join_key)]
162
-
163
- def has_drop_cols(self) -> bool:
164
- """Checks if any column is marked to be dropped from the selection."""
165
- return any(not v.keep for v in self.renames)
166
167
 
167
- @property
168
- def drop_columns(self) -> List[SelectInput]:
169
- """Returns a list of column names that are marked to be dropped from the selection."""
170
- return [v for v in self.renames if not v.keep and v.is_available]
168
+ class BasicFilter(BaseModel):
169
+ """Defines a simple, single-condition filter (e.g., 'column' 'equals' 'value')."""
170
+ field: Optional[str] = ''
171
+ filter_type: Optional[str] = ''
172
+ filter_value: Optional[str] = ''
171
173
 
172
- @property
173
- def non_jk_drop_columns(self) -> List[SelectInput]:
174
- return [v for v in self.renames if not v.keep and v.is_available and not v.join_key]
174
+ def __init__(self, field: str = None, filter_type: str = None, filter_value: str = None, **data):
175
+ if field is not None:
176
+ data['field'] = field
177
+ if filter_type is not None:
178
+ data['filter_type'] = filter_type
179
+ if filter_value is not None:
180
+ data['filter_value'] = filter_value
181
+ super().__init__(**data)
175
182
 
176
- def __add__(self, other: "SelectInput"):
177
- """Allows adding a SelectInput using the '+' operator."""
178
- self.renames.append(other)
179
183
 
180
- def append(self, other: "SelectInput"):
181
- """Appends a new SelectInput to the list of renames."""
182
- self.renames.append(other)
184
+ class FilterInput(BaseModel):
185
+ """Defines the settings for a filter operation, supporting basic or advanced (expression-based) modes."""
186
+ advanced_filter: Optional[str] = ''
187
+ basic_filter: Optional[BasicFilter] = None
188
+ filter_type: Optional[str] = 'basic'
189
+
190
+ def __init__(self, advanced_filter: str = None, basic_filter: BasicFilter = None,
191
+ filter_type: str = None, **data):
192
+ if advanced_filter is not None:
193
+ data['advanced_filter'] = advanced_filter
194
+ if basic_filter is not None:
195
+ data['basic_filter'] = basic_filter
196
+ if filter_type is not None:
197
+ data['filter_type'] = filter_type
198
+ super().__init__(**data)
199
+
200
+
201
+ class SelectInputs(BaseModel):
202
+ """A container for a list of `SelectInput` objects (pure data, no logic)."""
203
+ renames: List[SelectInput] = Field(default_factory=list)
204
+
205
+ def __init__(self, renames: List[SelectInput] = None, **kwargs):
206
+ if renames is not None:
207
+ kwargs['renames'] = renames
208
+ else:
209
+ kwargs['renames'] = []
210
+ super().__init__(**kwargs)
183
211
 
184
- def remove_select_input(self, old_key: str):
185
- """Removes a SelectInput from the list based on its original name."""
186
- self.renames = [rename for rename in self.renames if rename.old_name != old_key]
212
+ def to_yaml_dict(self) -> JoinInputsYaml:
213
+ """Serialize for YAML output."""
214
+ return {"select": [r.to_yaml_dict() for r in self.renames]}
187
215
 
188
- def unselect_field(self, old_key: str):
189
- """Marks a field to be dropped from the final selection by setting `keep` to False."""
190
- for rename in self.renames:
191
- if old_key == rename.old_name:
192
- rename.keep = False
216
+ @classmethod
217
+ def from_yaml_dict(cls, data: dict) -> "SelectInputs":
218
+ """Load from slim YAML format. Supports both 'select' (new) and 'renames' (internal)."""
219
+ items = data.get("select", data.get("renames", []))
220
+ return cls(renames=[SelectInput.from_yaml_dict(item) for item in items])
193
221
 
194
222
  @classmethod
195
- def create_from_list(cls, col_list: List[str]):
223
+ def create_from_list(cls, col_list: List[str]) -> "SelectInputs":
196
224
  """Creates a SelectInputs object from a simple list of column names."""
197
- return cls([SelectInput(c) for c in col_list])
225
+ return cls(renames=[SelectInput(old_name=c) for c in col_list])
198
226
 
199
227
  @classmethod
200
- def create_from_pl_df(cls, df: pl.DataFrame | pl.LazyFrame):
228
+ def create_from_pl_df(cls, df: pl.DataFrame | pl.LazyFrame) -> "SelectInputs":
201
229
  """Creates a SelectInputs object from a Polars DataFrame's columns."""
202
- return cls([SelectInput(c) for c in df.columns])
203
-
204
- def get_select_input_on_old_name(self, old_name: str) -> SelectInput | None:
205
- return next((v for v in self.renames if v.old_name == old_name), None)
206
-
207
- def get_select_input_on_new_name(self, old_name: str) -> SelectInput | None:
208
- return next((v for v in self.renames if v.new_name == old_name), None)
230
+ return cls(renames=[SelectInput(old_name=c) for c in df.columns])
209
231
 
210
232
 
211
233
  class JoinInputs(SelectInputs):
212
- """Extends `SelectInputs` with functionality specific to join operations, like handling join keys."""
213
-
214
- def __init__(self, renames: List[SelectInput]):
215
- self.renames = renames
216
-
217
- @property
218
- def join_key_selects(self) -> List[SelectInput]:
219
- """Returns only the `SelectInput` objects that are marked as join keys."""
220
- return [v for v in self.renames if v.join_key]
234
+ """Data model for join-specific select inputs (extends SelectInputs)."""
221
235
 
222
- def get_join_key_renames(self, side: SideLit, filter_drop: bool = False) -> JoinKeyRenameResponse:
223
- """Gets the temporary rename mapping for all join keys on one side of a join."""
224
- return JoinKeyRenameResponse(
225
- side,
226
- [JoinKeyRename(jk.new_name,
227
- construct_join_key_name(side, jk.new_name))
228
- for jk in self.join_key_selects if jk.keep or not filter_drop]
229
- )
230
-
231
- def get_join_key_rename_mapping(self, side: SideLit) -> Dict[str, str]:
232
- """Returns a dictionary mapping original join key names to their temporary names."""
233
- return {jkr[0]: jkr[1] for jkr in self.get_join_key_renames(side)[1]}
236
+ def __init__(self, renames: List[SelectInput] = None, **kwargs):
237
+ if renames is not None:
238
+ kwargs['renames'] = renames
239
+ else:
240
+ kwargs['renames'] = []
241
+ super().__init__(**kwargs)
234
242
 
235
243
 
236
- @dataclass
237
- class JoinMap:
244
+ class JoinMap(BaseModel):
238
245
  """Defines a single mapping between a left and right column for a join key."""
239
- left_col: str
240
- right_col: str
241
-
242
-
243
- class JoinSelectMixin:
244
- """A mixin providing common methods for join-like operations that involve left and right inputs."""
245
- left_select: JoinInputs = None
246
- right_select: JoinInputs = None
247
-
248
- @staticmethod
249
- def parse_select(select: List[SelectInput] | List[str] | List[Dict]) -> JoinInputs | None:
250
- """Parses various input formats into a standardized `JoinInputs` object."""
251
- if all(isinstance(c, SelectInput) for c in select):
252
- return JoinInputs(select)
253
- elif all(isinstance(c, dict) for c in select):
254
- return JoinInputs([SelectInput(**c.__dict__) for c in select])
255
- elif isinstance(select, dict):
256
- renames = select.get('renames')
257
- if renames:
258
- return JoinInputs([SelectInput(**c) for c in renames])
259
- elif all(isinstance(c, str) for c in select):
260
- return JoinInputs([SelectInput(s, s) for s in select])
261
-
262
- def auto_generate_new_col_name(self, old_col_name: str, side: str) -> str:
263
- """Generates a new, non-conflicting column name by adding a suffix if necessary."""
264
- current_names = self.left_select.new_cols & self.right_select.new_cols
265
- if old_col_name not in current_names:
266
- return old_col_name
267
- while True:
268
- if old_col_name not in current_names:
269
- return old_col_name
270
- old_col_name = f'{side}_{old_col_name}'
271
-
272
- def add_new_select_column(self, select_input: SelectInput, side: str):
273
- """Adds a new column to the selection for either the left or right side."""
274
- selects = self.right_select if side == 'right' else self.left_select
275
- select_input.new_name = self.auto_generate_new_col_name(select_input.old_name, side=side)
276
- selects.__add__(select_input)
277
-
278
-
279
- @dataclass
280
- class CrossJoinInput(JoinSelectMixin):
281
- """Defines the settings for a cross join operation, including column selections for both inputs."""
282
- left_select: SelectInputs = None
283
- right_select: SelectInputs = None
246
+ left_col: Optional[str] = None
247
+ right_col: Optional[str] = None
248
+
249
+ def __init__(self, left_col: str = None, right_col: str = None, **data):
250
+ if left_col is not None:
251
+ data['left_col'] = left_col
252
+ if right_col is not None:
253
+ data['right_col'] = right_col
254
+ super().__init__(**data)
255
+
256
+ @model_validator(mode='after')
257
+ def set_default_right_col(self):
258
+ """If right_col is None, default it to left_col."""
259
+ if self.right_col is None:
260
+ self.right_col = self.left_col
261
+ return self
262
+
263
+
264
+ class CrossJoinInput(BaseModel):
265
+ """Data model for cross join operations."""
266
+ left_select: JoinInputs
267
+ right_select: JoinInputs
268
+
269
+ @model_validator(mode='before')
270
+ @classmethod
271
+ def parse_inputs(cls, data: Any) -> Any:
272
+ """Parse flexible input formats before validation."""
273
+ if isinstance(data, dict):
274
+ # Parse join_mapping
275
+ if 'join_mapping' in data:
276
+ data['join_mapping'] = cls._parse_join_mapping(data['join_mapping'])
284
277
 
285
- def __init__(self, left_select: List[SelectInput] | List[str],
286
- right_select: List[SelectInput] | List[str]):
287
- """Initializes the CrossJoinInput with selections for left and right tables."""
288
- self.left_select = self.parse_select(left_select)
289
- self.right_select = self.parse_select(right_select)
278
+ # Parse left_select
279
+ if 'left_select' in data:
280
+ data['left_select'] = cls._parse_select(data['left_select'])
290
281
 
291
- @property
292
- def overlapping_records(self):
293
- """Finds column names that would conflict after the join."""
294
- return self.left_select.new_cols & self.right_select.new_cols
282
+ # Parse right_select
283
+ if 'right_select' in data:
284
+ data['right_select'] = cls._parse_select(data['right_select'])
295
285
 
296
- def auto_rename(self):
297
- """Automatically renames columns on the right side to prevent naming conflicts."""
298
- overlapping_records = self.overlapping_records
299
- while len(overlapping_records) > 0:
300
- for right_col in self.right_select.renames:
301
- if right_col.new_name in overlapping_records:
302
- right_col.new_name = 'right_' + right_col.new_name
303
- overlapping_records = self.overlapping_records
286
+ return data
304
287
 
288
+ @staticmethod
289
+ def _parse_join_mapping(join_mapping: Any) -> List[JoinMap]:
290
+ """Parse various join_mapping formats."""
291
+ # Already a list of JoinMaps
292
+ if isinstance(join_mapping, list):
293
+ result = []
294
+ for jm in join_mapping:
295
+ if isinstance(jm, JoinMap):
296
+ result.append(jm)
297
+ elif isinstance(jm, dict):
298
+ result.append(JoinMap(**jm))
299
+ elif isinstance(jm, (tuple, list)) and len(jm) == 2:
300
+ result.append(JoinMap(left_col=jm[0], right_col=jm[1]))
301
+ elif isinstance(jm, str):
302
+ result.append(JoinMap(left_col=jm, right_col=jm))
303
+ else:
304
+ raise ValueError(f"Invalid join mapping item: {jm}")
305
+ return result
306
+
307
+ # Single JoinMap
308
+ if isinstance(join_mapping, JoinMap):
309
+ return [join_mapping]
310
+
311
+ # String: same column on both sides
312
+ if isinstance(join_mapping, str):
313
+ return [JoinMap(left_col=join_mapping, right_col=join_mapping)]
314
+
315
+ # Tuple: (left, right)
316
+ if isinstance(join_mapping, tuple) and len(join_mapping) == 2:
317
+ return [JoinMap(left_col=join_mapping[0], right_col=join_mapping[1])]
318
+
319
+ raise ValueError(f"Invalid join_mapping format: {type(join_mapping)}")
305
320
 
306
- @dataclass
307
- class JoinInput(JoinSelectMixin):
308
- """Defines the settings for a standard SQL-style join, including keys, strategy, and selections."""
321
+ @staticmethod
322
+ def _parse_select(select: Any) -> JoinInputs:
323
+ """Parse various select input formats."""
324
+ # Already JoinInputs
325
+ if isinstance(select, JoinInputs):
326
+ return select
327
+
328
+ # List of SelectInput objects
329
+ if isinstance(select, list):
330
+ if all(isinstance(s, SelectInput) for s in select):
331
+ return JoinInputs(renames=select)
332
+ elif all(isinstance(s, str) for s in select):
333
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
334
+ elif all(isinstance(s, dict) for s in select):
335
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
336
+
337
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
338
+ if isinstance(select, dict):
339
+ if 'select' in select:
340
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
341
+ if 'renames' in select:
342
+ return JoinInputs(**select)
343
+
344
+ raise ValueError(f"Invalid select format: {type(select)}")
345
+
346
+ def __init__(self,
347
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
348
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
349
+ **data):
350
+ """Custom init for backward compatibility with positional arguments."""
351
+ if left_select is not None:
352
+ data['left_select'] = left_select
353
+ if right_select is not None:
354
+ data['right_select'] = right_select
355
+ super().__init__(**data)
356
+
357
+ def to_yaml_dict(self) -> CrossJoinInputYaml:
358
+ """Serialize for YAML output."""
359
+ return {
360
+ "left_select": self.left_select.to_yaml_dict(),
361
+ "right_select": self.right_select.to_yaml_dict(),
362
+ }
363
+
364
+
365
+ class JoinInput(BaseModel):
366
+ """Data model for standard SQL-style join operations."""
309
367
  join_mapping: List[JoinMap]
310
- left_select: JoinInputs = None
311
- right_select: JoinInputs = None
368
+ left_select: JoinInputs
369
+ right_select: JoinInputs
312
370
  how: JoinStrategy = 'inner'
313
371
 
314
- @staticmethod
315
- def parse_join_mapping(join_mapping: any) -> List[JoinMap]:
316
- """Parses various input formats for join keys into a standardized list of `JoinMap` objects."""
317
- if isinstance(join_mapping, (tuple, list)):
318
- assert len(join_mapping) > 0
319
- if all(isinstance(jm, dict) for jm in join_mapping):
320
- join_mapping = [JoinMap(**jm) for jm in join_mapping]
321
-
322
- if not isinstance(join_mapping[0], JoinMap):
323
- assert len(join_mapping) <= 2
324
- if len(join_mapping) == 2:
325
- assert isinstance(join_mapping[0], str) and isinstance(join_mapping[1], str)
326
- join_mapping = [JoinMap(*join_mapping)]
327
- elif isinstance(join_mapping[0], str):
328
- join_mapping = [JoinMap(join_mapping[0], join_mapping[0])]
329
- elif isinstance(join_mapping, str):
330
- join_mapping = [JoinMap(join_mapping, join_mapping)]
331
- else:
332
- raise Exception('No valid join mapping as input')
333
- return join_mapping
334
-
335
- def __init__(self, join_mapping: List[JoinMap] | Tuple[str, str] | str,
336
- left_select: List[SelectInput] | List[str],
337
- right_select: List[SelectInput] | List[str],
338
- how: JoinStrategy = 'inner'):
339
- """Initializes the JoinInput with keys, selections, and join strategy."""
340
- self.join_mapping = self.parse_join_mapping(join_mapping)
341
- self.left_select = self.parse_select(left_select)
342
- self.right_select = self.parse_select(right_select)
343
- self.set_join_keys()
344
- self.how = how
345
-
346
- def set_join_keys(self):
347
- """Marks the `SelectInput` objects corresponding to join keys."""
348
- [setattr(v, "join_key", v.old_name in self._left_join_keys) for v in self.left_select.renames]
349
- [setattr(v, "join_key", v.old_name in self._right_join_keys) for v in self.right_select.renames]
350
-
351
- def get_join_key_renames(self, filter_drop: bool = False) -> FullJoinKeyResponse:
352
- """Gets the temporary rename mappings for the join keys on both sides."""
353
- return FullJoinKeyResponse(self.left_select.get_join_key_renames(side="left", filter_drop=filter_drop),
354
- self.right_select.get_join_key_renames(side="right", filter_drop=filter_drop))
355
-
356
- def get_names_for_table_rename(self) -> List[JoinMap]:
357
- new_mappings: List[JoinMap] = []
358
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
359
- for join_map in self.join_mapping:
360
- new_mappings.append(JoinMap(left_rename_table.get(join_map.left_col, join_map.left_col),
361
- right_rename_table.get(join_map.right_col, join_map.right_col)
362
- )
363
- )
364
- return new_mappings
365
-
366
- @property
367
- def _left_join_keys(self) -> Set:
368
- """Returns a set of the left-side join key column names."""
369
- return set(jm.left_col for jm in self.join_mapping)
370
-
371
- @property
372
- def _right_join_keys(self) -> Set:
373
- """Returns a set of the right-side join key column names."""
374
- return set(jm.right_col for jm in self.join_mapping)
372
+ @model_validator(mode='before')
373
+ @classmethod
374
+ def parse_inputs(cls, data: Any) -> Any:
375
+ """Parse flexible input formats before validation."""
376
+ if isinstance(data, dict):
377
+ # Parse join_mapping
378
+ if 'join_mapping' in data:
379
+ data['join_mapping'] = cls._parse_join_mapping(data['join_mapping'])
375
380
 
376
- @property
377
- def left_join_keys(self) -> List[str]:
378
- """Returns an ordered list of the left-side join key column names to be used in the join."""
379
- return [jm.left_col for jm in self.used_join_mapping]
381
+ # Parse left_select
382
+ if 'left_select' in data:
383
+ data['left_select'] = cls._parse_select(data['left_select'])
380
384
 
381
- @property
382
- def right_join_keys(self) -> List[str]:
383
- """Returns an ordered list of the right-side join key column names to be used in the join."""
384
- return [jm.right_col for jm in self.used_join_mapping]
385
+ # Parse right_select
386
+ if 'right_select' in data:
387
+ data['right_select'] = cls._parse_select(data['right_select'])
385
388
 
386
- @property
387
- def overlapping_records(self):
388
- if self.how in ('left', 'right', 'inner'):
389
- return self.left_select.new_cols & self.right_select.new_cols
390
- else:
391
- return self.left_select.new_cols & self.right_select.new_cols
392
-
393
- def auto_rename(self):
394
- """Automatically renames columns on the right side to prevent naming conflicts."""
395
- self.set_join_keys()
396
- overlapping_records = self.overlapping_records
397
- while len(overlapping_records) > 0:
398
- for right_col in self.right_select.renames:
399
- if right_col.new_name in overlapping_records:
400
- right_col.new_name = right_col.new_name + '_right'
401
- overlapping_records = self.overlapping_records
402
-
403
- @property
404
- def used_join_mapping(self) -> List[JoinMap]:
405
- """Returns the final join mapping after applying all renames and transformations."""
406
- new_mappings: List[JoinMap] = []
407
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
408
- left_join_rename_mapping: Dict[str, str] = self.left_select.get_join_key_rename_mapping("left")
409
- right_join_rename_mapping: Dict[str, str] = self.right_select.get_join_key_rename_mapping("right")
410
- for join_map in self.join_mapping:
411
- # del self.right_select.rename_table, self.left_select.rename_table
412
- new_mappings.append(JoinMap(left_join_rename_mapping.get(left_rename_table.get(join_map.left_col, join_map.left_col)),
413
- right_join_rename_mapping.get(right_rename_table.get(join_map.right_col, join_map.right_col))
414
- )
415
- )
416
- return new_mappings
389
+ return data
417
390
 
391
+ @staticmethod
392
+ def _parse_join_mapping(join_mapping: Any) -> List[JoinMap]:
393
+ """Parse various join_mapping formats."""
394
+ # Already a list of JoinMaps
395
+ if isinstance(join_mapping, list):
396
+ result = []
397
+ for jm in join_mapping:
398
+ if isinstance(jm, JoinMap):
399
+ result.append(jm)
400
+ elif isinstance(jm, dict):
401
+ result.append(JoinMap(**jm))
402
+ elif isinstance(jm, (tuple, list)) and len(jm) == 2:
403
+ result.append(JoinMap(left_col=jm[0], right_col=jm[1]))
404
+ elif isinstance(jm, str):
405
+ result.append(JoinMap(left_col=jm, right_col=jm))
406
+ else:
407
+ raise ValueError(f"Invalid join mapping item: {jm}")
408
+ return result
409
+
410
+ # Single JoinMap
411
+ if isinstance(join_mapping, JoinMap):
412
+ return [join_mapping]
413
+
414
+ # String: same column on both sides
415
+ if isinstance(join_mapping, str):
416
+ return [JoinMap(left_col=join_mapping, right_col=join_mapping)]
417
+
418
+ # Tuple: (left, right)
419
+ if isinstance(join_mapping, tuple) and len(join_mapping) == 2:
420
+ return [JoinMap(left_col=join_mapping[0], right_col=join_mapping[1])]
421
+
422
+ raise ValueError(f"Invalid join_mapping format: {type(join_mapping)}")
418
423
 
419
- @dataclass
420
- class FuzzyMatchInput(JoinInput):
421
- """Extends `JoinInput` with settings specific to fuzzy matching, such as the matching algorithm and threshold."""
424
+ @staticmethod
425
+ def _parse_select(select: Any) -> JoinInputs:
426
+ """Parse various select input formats."""
427
+ # Already JoinInputs
428
+ if isinstance(select, JoinInputs):
429
+ return select
430
+
431
+ # List of SelectInput objects
432
+ if isinstance(select, list):
433
+ if all(isinstance(s, SelectInput) for s in select):
434
+ return JoinInputs(renames=select)
435
+ elif all(isinstance(s, str) for s in select):
436
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
437
+ elif all(isinstance(s, dict) for s in select):
438
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
439
+
440
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
441
+ if isinstance(select, dict):
442
+ if 'select' in select:
443
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
444
+ if 'renames' in select:
445
+ return JoinInputs(**select)
446
+
447
+ raise ValueError(f"Invalid select format: {type(select)}")
448
+
449
+ def __init__(self,
450
+ join_mapping: Union[List[JoinMap], JoinMap, Tuple[str, str], str, List[Tuple], List[str]] = None,
451
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
452
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
453
+ how: JoinStrategy = 'inner',
454
+ **data):
455
+ """Custom init for backward compatibility with positional arguments."""
456
+ if join_mapping is not None:
457
+ data['join_mapping'] = join_mapping
458
+ if left_select is not None:
459
+ data['left_select'] = left_select
460
+ if right_select is not None:
461
+ data['right_select'] = right_select
462
+ if how is not None:
463
+ data['how'] = how
464
+
465
+ super().__init__(**data)
466
+
467
+ def to_yaml_dict(self) -> JoinInputYaml:
468
+ """Serialize for YAML output."""
469
+ return {
470
+ "join_mapping": [{"left_col": jm.left_col, "right_col": jm.right_col} for jm in self.join_mapping],
471
+ "left_select": self.left_select.to_yaml_dict(),
472
+ "right_select": self.right_select.to_yaml_dict(),
473
+ "how": self.how,
474
+ }
475
+
476
+
477
+ class FuzzyMatchInput(BaseModel):
478
+ """Data model for fuzzy matching join operations."""
422
479
  join_mapping: List[FuzzyMapping]
480
+ left_select: JoinInputs
481
+ right_select: JoinInputs
482
+ how: JoinStrategy = 'inner'
423
483
  aggregate_output: bool = False
424
484
 
425
- @staticmethod
426
- def parse_fuzz_mapping(fuzz_mapping: List[FuzzyMapping] | Tuple[str, str] | str) -> List[FuzzyMapping]:
427
- if isinstance(fuzz_mapping, (tuple, list)):
428
- assert len(fuzz_mapping) > 0
429
- if all(isinstance(fm, dict) for fm in fuzz_mapping):
430
- fuzz_mapping = [FuzzyMapping(**fm) for fm in fuzz_mapping]
485
+ def __init__(self,
486
+ left_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
487
+ right_select: Union[JoinInputs, List[SelectInput], List[str]] = None,
488
+ **data):
489
+ """Custom init for backward compatibility with positional arguments."""
490
+ if left_select is not None:
491
+ data['left_select'] = left_select
492
+ if right_select is not None:
493
+ data['right_select'] = right_select
494
+
495
+ super().__init__(**data)
496
+
497
+ def to_yaml_dict(self) -> FuzzyMatchInputYaml:
498
+ """Serialize for YAML output."""
499
+ return {
500
+ "join_mapping": [asdict(jm) for jm in self.join_mapping],
501
+ "left_select": self.left_select.to_yaml_dict(),
502
+ "right_select": self.right_select.to_yaml_dict(),
503
+ "how": self.how,
504
+ "aggregate_output": self.aggregate_output,
505
+ }
431
506
 
432
- if not isinstance(fuzz_mapping[0], FuzzyMapping):
433
- assert len(fuzz_mapping) <= 2
434
- if len(fuzz_mapping) == 2:
435
- assert isinstance(fuzz_mapping[0], str) and isinstance(fuzz_mapping[1], str)
436
- fuzz_mapping = [FuzzyMapping(*fuzz_mapping)]
437
- elif isinstance(fuzz_mapping[0], str):
438
- fuzz_mapping = [FuzzyMapping(fuzz_mapping[0], fuzz_mapping[0])]
439
- elif isinstance(fuzz_mapping, str):
440
- fuzz_mapping = [FuzzyMapping(fuzz_mapping, fuzz_mapping)]
441
- elif isinstance(fuzz_mapping, FuzzyMapping):
442
- fuzz_mapping = [fuzz_mapping]
443
- else:
444
- raise Exception('No valid join mapping as input')
445
- return fuzz_mapping
446
-
447
- def __init__(self, join_mapping: List[FuzzyMapping] | Tuple[str, str] | str, left_select: List[SelectInput] | List[str],
448
- right_select: List[SelectInput] | List[str], aggregate_output: bool = False, how: JoinStrategy = 'inner'):
449
- self.join_mapping = self.parse_fuzz_mapping(join_mapping)
450
- self.left_select = self.parse_select(left_select)
451
- self.right_select = self.parse_select(right_select)
452
- self.how = how
453
- for jm in self.join_mapping:
454
-
455
- if jm.right_col not in {v.old_name for v in self.right_select.renames}:
456
- self.right_select.append(SelectInput(jm.right_col, keep=False, join_key=True))
457
- if jm.left_col not in {v.old_name for v in self.left_select.renames}:
458
- self.left_select.append(SelectInput(jm.left_col, keep=False, join_key=True))
459
- [setattr(v, "join_key", v.old_name in self._left_join_keys) for v in self.left_select.renames]
460
- [setattr(v, "join_key", v.old_name in self._right_join_keys) for v in self.right_select.renames]
461
- self.aggregate_output = aggregate_output
507
+ @staticmethod
508
+ def _parse_select(select: Any) -> JoinInputs:
509
+ """Parse various select input formats."""
510
+ # Already JoinInputs
511
+ if isinstance(select, JoinInputs):
512
+ return select
513
+
514
+ # List of SelectInput objects
515
+ if isinstance(select, list):
516
+ if all(isinstance(s, SelectInput) for s in select):
517
+ return JoinInputs(renames=select)
518
+ elif all(isinstance(s, str) for s in select):
519
+ return JoinInputs(renames=[SelectInput(old_name=s) for s in select])
520
+ elif all(isinstance(s, dict) for s in select):
521
+ return JoinInputs(renames=[SelectInput(**s) for s in select])
522
+
523
+ # Dict with 'select' (new YAML) or 'renames' (internal) key
524
+ if isinstance(select, dict):
525
+ if 'select' in select:
526
+ return JoinInputs(renames=[SelectInput.from_yaml_dict(s) for s in select['select']])
527
+ if 'renames' in select:
528
+ return JoinInputs(**select)
529
+
530
+ raise ValueError(f"Invalid select format: {type(select)}")
531
+
532
+ @model_validator(mode='before')
533
+ @classmethod
534
+ def parse_inputs(cls, data: Any) -> Any:
535
+ """Parse flexible input formats before validation."""
536
+ if isinstance(data, dict):
537
+ # Parse left_select
538
+ if 'left_select' in data:
539
+ data['left_select'] = cls._parse_select(data['left_select'])
462
540
 
463
- @property
464
- def overlapping_records(self):
465
- return self.left_select.new_cols & self.right_select.new_cols
541
+ # Parse right_select
542
+ if 'right_select' in data:
543
+ data['right_select'] = cls._parse_select(data['right_select'])
466
544
 
467
- @property
468
- def fuzzy_maps(self) -> List[FuzzyMapping]:
469
- """Returns the final fuzzy mappings after applying all column renames."""
470
- new_mappings = []
471
- left_rename_table, right_rename_table = self.left_select.rename_table, self.right_select.rename_table
472
- for org_fuzzy_map in self.join_mapping:
473
- right_col = right_rename_table.get(org_fuzzy_map.right_col)
474
- left_col = left_rename_table.get(org_fuzzy_map.left_col)
475
- if right_col != org_fuzzy_map.right_col or left_col != org_fuzzy_map.left_col:
476
- new_mapping = deepcopy(org_fuzzy_map)
477
- new_mapping.left_col = left_col
478
- new_mapping.right_col = right_col
479
- new_mappings.append(new_mapping)
480
- else:
481
- new_mappings.append(org_fuzzy_map)
482
- return new_mappings
545
+ return data
483
546
 
484
547
 
485
- @dataclass
486
- class AggColl:
548
+ class AggColl(BaseModel):
487
549
  """
488
550
  A data class that represents a single aggregation operation for a group by operation.
489
551
 
@@ -492,7 +554,7 @@ class AggColl:
492
554
  old_name : str
493
555
  The name of the column in the original DataFrame to be aggregated.
494
556
 
495
- agg : Any
557
+ agg : str
496
558
  The aggregation function to use. This can be a string representing a built-in function or a custom function.
497
559
 
498
560
  new_name : Optional[str]
@@ -514,18 +576,36 @@ class AggColl:
514
576
  """
515
577
  old_name: str
516
578
  agg: str
517
- new_name: Optional[str]
579
+ new_name: Optional[str] = None
518
580
  output_type: Optional[str] = None
519
581
 
520
- def __init__(self, old_name: str, agg: str, new_name: str = None, output_type: str = None):
521
- """Initializes an aggregation column with its source, function, and new name."""
522
- self.old_name = str(old_name)
523
- if agg != 'groupby':
524
- self.new_name = new_name if new_name is not None else self.old_name + "_" + agg
525
- else:
526
- self.new_name = new_name if new_name is not None else self.old_name
527
- self.output_type = output_type if output_type is not None else get_func_type_mapping(agg)
528
- self.agg = agg
582
+ def __init__(self, old_name: str, agg: str, new_name: Optional[str] = None, output_type: Optional[str] = None):
583
+ data = {'old_name': old_name, 'agg': agg}
584
+ if new_name is not None:
585
+ data['new_name'] = new_name
586
+ if output_type is not None:
587
+ data['output_type'] = output_type
588
+
589
+ super().__init__(**data)
590
+
591
+ @model_validator(mode='after')
592
+ def set_defaults(self):
593
+ """Set default new_name and output_type based on agg function."""
594
+ # Set new_name
595
+ if self.new_name is None:
596
+ if self.agg != 'groupby':
597
+ self.new_name = self.old_name + "_" + self.agg
598
+ else:
599
+ self.new_name = self.old_name
600
+
601
+ # Set output_type
602
+ if self.output_type is None:
603
+ self.output_type = get_func_type_mapping(self.agg)
604
+
605
+ # Ensure old_name is a string
606
+ self.old_name = str(self.old_name)
607
+
608
+ return self
529
609
 
530
610
  @property
531
611
  def agg_func(self):
@@ -538,16 +618,12 @@ class AggColl:
538
618
  return getattr(pl, self.agg) if isinstance(self.agg, str) else self.agg
539
619
 
540
620
 
541
- @dataclass
542
- class GroupByInput:
621
+ class GroupByInput(BaseModel):
543
622
  """
544
623
  A data class that represents the input for a group by operation.
545
624
 
546
625
  Attributes
547
626
  ----------
548
- group_columns : List[str]
549
- A list of column names to group the DataFrame by. These column(s) will be set as the DataFrame index.
550
-
551
627
  agg_cols : List[AggColl]
552
628
  A list of `AggColl` objects that specify the aggregation operations to perform on the DataFrame columns
553
629
  after grouping. Each `AggColl` object should specify the column to be aggregated and the aggregation
@@ -556,14 +632,18 @@ class GroupByInput:
556
632
  Example
557
633
  --------
558
634
  group_by_input = GroupByInput(
559
- agg_cols=[AggColl(old_name='ix', agg='groupby'), AggColl(old_name='groups', agg='groupby'), AggColl(old_name='col1', agg='sum'), AggColl(old_name='col2', agg='mean')]
635
+ agg_cols=[AggColl(old_name='ix', agg='groupby'), AggColl(old_name='groups', agg='groupby'),
636
+ AggColl(old_name='col1', agg='sum'), AggColl(old_name='col2', agg='mean')]
560
637
  )
561
638
  """
562
639
  agg_cols: List[AggColl]
563
640
 
641
+ def __init__(self, agg_cols: List[AggColl]):
642
+ """Backwards compatibility implementation"""
643
+ super().__init__(agg_cols=agg_cols)
564
644
 
565
- @dataclass
566
- class PivotInput:
645
+
646
+ class PivotInput(BaseModel):
567
647
  """Defines the settings for a pivot (long-to-wide) operation."""
568
648
  index_columns: List[str]
569
649
  pivot_column: str
@@ -577,11 +657,13 @@ class PivotInput:
577
657
 
578
658
  def get_group_by_input(self) -> GroupByInput:
579
659
  """Constructs the `GroupByInput` needed for the pre-aggregation step of the pivot."""
580
- group_by_cols = [AggColl(c, 'groupby') for c in self.grouped_columns]
581
- agg_cols = [AggColl(self.value_col, agg=aggregation, new_name=aggregation) for aggregation in self.aggregations]
582
- return GroupByInput(group_by_cols+agg_cols)
660
+ group_by_cols = [AggColl(old_name=c, agg='groupby') for c in self.grouped_columns]
661
+ agg_cols = [AggColl(old_name=self.value_col, agg=aggregation, new_name=aggregation)
662
+ for aggregation in self.aggregations]
663
+ return GroupByInput(agg_cols=group_by_cols + agg_cols)
583
664
 
584
665
  def get_index_columns(self) -> List[pl.col]:
666
+ """Returns the index columns as Polars column expressions."""
585
667
  return [pl.col(c) for c in self.index_columns]
586
668
 
587
669
  def get_pivot_column(self) -> pl.Expr:
@@ -593,24 +675,21 @@ class PivotInput:
593
675
  return pl.struct([pl.col(c) for c in self.aggregations]).alias('vals')
594
676
 
595
677
 
596
- @dataclass
597
- class SortByInput:
678
+ class SortByInput(BaseModel):
598
679
  """Defines a single sort condition on a column, including the direction."""
599
680
  column: str
600
- how: str = 'asc'
681
+ how: Optional[str] = 'asc'
601
682
 
602
683
 
603
- @dataclass
604
- class RecordIdInput:
684
+ class RecordIdInput(BaseModel):
605
685
  """Defines settings for adding a record ID (row number) column to the data."""
606
686
  output_column_name: str = 'record_id'
607
687
  offset: int = 1
608
688
  group_by: Optional[bool] = False
609
- group_by_columns: Optional[List[str]] = field(default_factory=list)
689
+ group_by_columns: Optional[List[str]] = Field(default_factory=list)
610
690
 
611
691
 
612
- @dataclass
613
- class TextToRowsInput:
692
+ class TextToRowsInput(BaseModel):
614
693
  """Defines settings for splitting a text column into multiple rows based on a delimiter."""
615
694
  column_to_split: str
616
695
  output_column_name: Optional[str] = None
@@ -619,22 +698,14 @@ class TextToRowsInput:
619
698
  split_by_column: Optional[str] = None
620
699
 
621
700
 
622
- @dataclass
623
- class UnpivotInput:
701
+ class UnpivotInput(BaseModel):
624
702
  """Defines settings for an unpivot (wide-to-long) operation."""
625
- index_columns: Optional[List[str]] = field(default_factory=list)
626
- value_columns: Optional[List[str]] = field(default_factory=list)
627
- data_type_selector: Optional[Literal['float', 'all', 'date', 'numeric', 'string']] = None
628
- data_type_selector_mode: Optional[Literal['data_type', 'column']] = 'column'
703
+ model_config = ConfigDict(arbitrary_types_allowed=True)
629
704
 
630
- def __post_init__(self):
631
- """Ensures that list attributes are initialized correctly if they are None."""
632
- if self.index_columns is None:
633
- self.index_columns = []
634
- if self.value_columns is None:
635
- self.value_columns = []
636
- if self.data_type_selector_mode is None:
637
- self.data_type_selector_mode = 'column'
705
+ index_columns: List[str] = Field(default_factory=list)
706
+ value_columns: List[str] = Field(default_factory=list)
707
+ data_type_selector: Optional[Literal['float', 'all', 'date', 'numeric', 'string']] = None
708
+ data_type_selector_mode: Literal['data_type', 'column'] = 'column'
638
709
 
639
710
  @property
640
711
  def data_type_selector_expr(self) -> Optional[Callable]:
@@ -647,30 +718,630 @@ class UnpivotInput:
647
718
  print(f'Could not find the selector: {self.data_type_selector}')
648
719
  return selectors.all
649
720
  return selectors.all
721
+ return None
650
722
 
651
723
 
652
- @dataclass
653
- class UnionInput:
724
+ class UnionInput(BaseModel):
654
725
  """Defines settings for a union (concatenation) operation."""
655
726
  mode: Literal['selective', 'relaxed'] = 'relaxed'
656
727
 
657
728
 
658
- @dataclass
659
- class UniqueInput:
729
+ class UniqueInput(BaseModel):
660
730
  """Defines settings for a uniqueness operation, specifying columns and which row to keep."""
661
731
  columns: Optional[List[str]] = None
662
732
  strategy: Literal["first", "last", "any", "none"] = "any"
663
733
 
664
734
 
665
- @dataclass
666
- class GraphSolverInput:
735
+ class GraphSolverInput(BaseModel):
667
736
  """Defines settings for a graph-solving operation (e.g., finding connected components)."""
668
737
  col_from: str
669
738
  col_to: str
670
739
  output_column_name: Optional[str] = 'graph_group'
671
740
 
672
741
 
673
- @dataclass
674
- class PolarsCodeInput:
742
+ class PolarsCodeInput(BaseModel):
675
743
  """A simple container for a string of user-provided Polars code to be executed."""
676
744
  polars_code: str
745
+
746
+
747
+ class SelectInputsManager:
748
+ """Manager class that provides all query and mutation operations."""
749
+
750
+ def __init__(self, select_inputs: SelectInputs):
751
+ self.select_inputs = select_inputs
752
+
753
+ # === Query Methods (read-only) ===
754
+
755
+ def get_old_cols(self) -> Set[str]:
756
+ """Returns a set of original column names to be kept in the selection."""
757
+ return set(v.old_name for v in self.select_inputs.renames if v.keep)
758
+
759
+ def get_new_cols(self) -> Set[str]:
760
+ """Returns a set of new (renamed) column names to be kept in the selection."""
761
+ return set(v.new_name for v in self.select_inputs.renames if v.keep)
762
+
763
+ def get_rename_table(self) -> dict[str, str]:
764
+ """Generates a dictionary for use in Polars' `.rename()` method."""
765
+ return {v.old_name: v.new_name for v in self.select_inputs.renames
766
+ if v.is_available and (v.keep or v.join_key)}
767
+
768
+ def get_select_cols(self, include_join_key: bool = True) -> List[str]:
769
+ """Gets a list of original column names to select from the source DataFrame."""
770
+ return [v.old_name for v in self.select_inputs.renames
771
+ if v.keep or (v.join_key and include_join_key)]
772
+
773
+ def has_drop_cols(self) -> bool:
774
+ """Checks if any column is marked to be dropped from the selection."""
775
+ return any(not v.keep for v in self.select_inputs.renames)
776
+
777
+ def get_drop_columns(self) -> List[SelectInput]:
778
+ """Returns a list of SelectInput objects that are marked to be dropped."""
779
+ return [v for v in self.select_inputs.renames if not v.keep and v.is_available]
780
+
781
+ def get_non_jk_drop_columns(self) -> List[SelectInput]:
782
+ """Returns drop columns that are not join keys."""
783
+ return [v for v in self.select_inputs.renames
784
+ if not v.keep and v.is_available and not v.join_key]
785
+
786
+ def find_by_old_name(self, old_name: str) -> Optional[SelectInput]:
787
+ """Find SelectInput by original column name."""
788
+ return next((v for v in self.select_inputs.renames if v.old_name == old_name), None)
789
+
790
+ def find_by_new_name(self, new_name: str) -> Optional[SelectInput]:
791
+ """Find SelectInput by new column name."""
792
+ return next((v for v in self.select_inputs.renames if v.new_name == new_name), None)
793
+
794
+ # === Mutation Methods ===
795
+
796
+ def append(self, other: SelectInput) -> None:
797
+ """Appends a new SelectInput to the list of renames."""
798
+ self.select_inputs.renames.append(other)
799
+
800
+ def remove_select_input(self, old_key: str) -> None:
801
+ """Removes a SelectInput from the list based on its original name."""
802
+ self.select_inputs.renames = [
803
+ rename for rename in self.select_inputs.renames
804
+ if rename.old_name != old_key
805
+ ]
806
+
807
+ def unselect_field(self, old_key: str) -> None:
808
+ """Marks a field to be dropped from the final selection by setting `keep` to False."""
809
+ for rename in self.select_inputs.renames:
810
+ if old_key == rename.old_name:
811
+ rename.keep = False
812
+
813
+ # === Backward Compatibility Properties ===
814
+
815
+ @property
816
+ def old_cols(self) -> Set[str]:
817
+ """Backward compatibility: Returns set of old column names."""
818
+ return self.get_old_cols()
819
+
820
+ @property
821
+ def new_cols(self) -> Set[str]:
822
+ """Backward compatibility: Returns set of new column names."""
823
+ return self.get_new_cols()
824
+
825
+ @property
826
+ def rename_table(self) -> dict[str, str]:
827
+ """Backward compatibility: Returns rename table dictionary."""
828
+ return self.get_rename_table()
829
+
830
+ @property
831
+ def drop_columns(self) -> List[SelectInput]:
832
+ """Backward compatibility: Returns list of columns to drop."""
833
+ return self.get_drop_columns()
834
+
835
+ @property
836
+ def non_jk_drop_columns(self) -> List[SelectInput]:
837
+ """Backward compatibility: Returns non-join-key columns to drop."""
838
+ return self.get_non_jk_drop_columns()
839
+
840
+ @property
841
+ def renames(self) -> List[SelectInput]:
842
+ """Backward compatibility: Direct access to renames list."""
843
+ return self.select_inputs.renames
844
+
845
+ def get_select_input_on_old_name(self, old_name: str) -> Optional[SelectInput]:
846
+ """Backward compatibility alias: Find SelectInput by original column name."""
847
+ return self.find_by_old_name(old_name)
848
+
849
+ def get_select_input_on_new_name(self, new_name: str) -> Optional[SelectInput]:
850
+ """Backward compatibility alias: Find SelectInput by new column name."""
851
+ return self.find_by_new_name(new_name)
852
+
853
+ def __add__(self, other: SelectInput) -> "SelectInputsManager":
854
+ """Backward compatibility: Support += operator for appending."""
855
+ self.append(other)
856
+ return self
857
+
858
+
859
+ class JoinInputsManager(SelectInputsManager):
860
+ """Manager for join-specific operations, extends SelectInputsManager."""
861
+
862
+ def __init__(self, join_inputs: JoinInputs):
863
+ super().__init__(join_inputs)
864
+ self.join_inputs = join_inputs
865
+
866
+ # === Query Methods ===
867
+
868
+ def get_join_key_selects(self) -> List[SelectInput]:
869
+ """Returns only the `SelectInput` objects that are marked as join keys."""
870
+ return [v for v in self.join_inputs.renames if v.join_key]
871
+
872
+ def get_join_key_renames(self, side: SideLit, filter_drop: bool = False) -> JoinKeyRenameResponse:
873
+ """Gets the temporary rename mapping for all join keys on one side of a join."""
874
+ join_key_selects = self.get_join_key_selects()
875
+ join_key_list = [
876
+ JoinKeyRename(jk.new_name, construct_join_key_name(side, jk.new_name))
877
+ for jk in join_key_selects
878
+ if jk.keep or not filter_drop
879
+ ]
880
+ return JoinKeyRenameResponse(side, join_key_list)
881
+
882
+ def get_join_key_rename_mapping(self, side: SideLit) -> Dict[str, str]:
883
+ """Returns a dictionary mapping original join key names to their temporary names."""
884
+ join_key_response = self.get_join_key_renames(side)
885
+ return {jkr.original_name: jkr.temp_name for jkr in join_key_response.join_key_renames}
886
+
887
+ @property
888
+ def join_key_selects(self) -> List[SelectInput]:
889
+ """Backward compatibility: Returns join key SelectInputs."""
890
+ return self.get_join_key_selects()
891
+
892
+
893
+ class JoinSelectManagerMixin:
894
+ """Mixin providing common methods for join-like operations."""
895
+
896
+ left_manager: JoinInputsManager
897
+ right_manager: JoinInputsManager
898
+ input: Union[CrossJoinInput, JoinInput, FuzzyMatchInput]
899
+
900
+ @staticmethod
901
+ def parse_select(select: Union[List[SelectInput], List[str], List[Dict], Dict]) -> JoinInputs:
902
+ """Parses various input formats into a standardized `JoinInputs` object."""
903
+ if not select:
904
+ return JoinInputs(renames=[])
905
+
906
+ if all(isinstance(c, SelectInput) for c in select):
907
+ return JoinInputs(renames=select)
908
+ elif all(isinstance(c, dict) for c in select):
909
+ return JoinInputs(renames=[SelectInput(**c) for c in select])
910
+ elif isinstance(select, dict):
911
+ renames = select.get('renames')
912
+ if renames:
913
+ return JoinInputs(renames=[SelectInput(**c) for c in renames])
914
+ return JoinInputs(renames=[])
915
+ elif all(isinstance(c, str) for c in select):
916
+ return JoinInputs(renames=[SelectInput(old_name=s, new_name=s) for s in select])
917
+
918
+ raise ValueError(f"Unable to parse select input: {type(select)}")
919
+
920
+ def get_overlapping_columns(self) -> Set[str]:
921
+ """Finds column names that would conflict after the join."""
922
+ return self.left_manager.get_new_cols() & self.right_manager.get_new_cols()
923
+
924
+ def auto_generate_new_col_name(self, old_col_name: str, side: str) -> str:
925
+ """Generates a new, non-conflicting column name by adding a suffix if necessary."""
926
+ current_names = self.get_overlapping_columns()
927
+ if old_col_name not in current_names:
928
+ return old_col_name
929
+
930
+ new_name = old_col_name
931
+ while new_name in current_names:
932
+ new_name = f'{side}_{new_name}'
933
+ return new_name
934
+
935
+ def add_new_select_column(self, select_input: SelectInput, side: str) -> None:
936
+ """Adds a new column to the selection for either the left or right side."""
937
+ target_input = self.input.right_select if side == 'right' else self.input.left_select
938
+
939
+ select_input.new_name = self.auto_generate_new_col_name(
940
+ select_input.old_name, side=side
941
+ )
942
+
943
+ target_input.renames.append(select_input)
944
+
945
+
946
+ class CrossJoinInputManager(JoinSelectManagerMixin):
947
+ """Manager for cross join operations."""
948
+
949
+ def __init__(self, cross_join_input: CrossJoinInput):
950
+ self.input = deepcopy(cross_join_input)
951
+ self.left_manager = JoinInputsManager(self.input.left_select)
952
+ self.right_manager = JoinInputsManager(self.input.right_select)
953
+
954
+ @classmethod
955
+ def create(cls, left_select: Union[List[SelectInput], List[str]],
956
+ right_select: Union[List[SelectInput], List[str]]) -> "CrossJoinInputManager":
957
+ """Factory method to create CrossJoinInput from various input formats."""
958
+ left_inputs = cls.parse_select(left_select)
959
+ right_inputs = cls.parse_select(right_select)
960
+
961
+ cross_join = CrossJoinInput(
962
+ left_select=left_inputs,
963
+ right_select=right_inputs
964
+ )
965
+ return cls(cross_join)
966
+
967
+ def get_overlapping_records(self) -> Set[str]:
968
+ """Finds column names that would conflict after the join."""
969
+ return self.get_overlapping_columns()
970
+
971
+ def auto_rename(self, rename_mode: Literal["suffix", "prefix"] = "prefix") -> None:
972
+ """Automatically renames columns on the right side to prevent naming conflicts."""
973
+ overlapping_records = self.get_overlapping_records()
974
+
975
+ while len(overlapping_records) > 0:
976
+ for right_col in self.input.right_select.renames:
977
+ if right_col.new_name in overlapping_records:
978
+ if rename_mode == "prefix":
979
+ right_col.new_name = 'right_' + right_col.new_name
980
+ elif rename_mode == "suffix":
981
+ right_col.new_name = right_col.new_name + '_right'
982
+ else:
983
+ raise ValueError(f'Unknown rename_mode: {rename_mode}')
984
+ overlapping_records = self.get_overlapping_records()
985
+
986
+ # === Backward Compatibility Properties ===
987
+
988
+ @property
989
+ def left_select(self) -> JoinInputsManager:
990
+ """Backward compatibility: Access left_manager as left_select."""
991
+ return self.left_manager
992
+
993
+ @property
994
+ def right_select(self) -> JoinInputsManager:
995
+ """Backward compatibility: Access right_manager as right_select."""
996
+ return self.right_manager
997
+
998
+ @property
999
+ def overlapping_records(self) -> Set[str]:
1000
+ """Backward compatibility: Returns overlapping column names."""
1001
+ return self.get_overlapping_records()
1002
+
1003
+ def to_cross_join_input(self) -> CrossJoinInput:
1004
+ """Creates a new CrossJoinInput instance based on the current manager settings.
1005
+
1006
+ This is useful when you've modified the manager (e.g., via auto_rename) and
1007
+ want to get a fresh CrossJoinInput with all the current settings applied.
1008
+
1009
+ Returns:
1010
+ A new CrossJoinInput instance with current settings
1011
+ """
1012
+ return CrossJoinInput(
1013
+ left_select=JoinInputs(renames=self.input.left_select.renames.copy()),
1014
+ right_select=JoinInputs(renames=self.input.right_select.renames.copy())
1015
+ )
1016
+
1017
+
1018
+ class JoinInputManager(JoinSelectManagerMixin):
1019
+ """Manager for standard SQL-style join operations."""
1020
+
1021
+ def __init__(self, join_input: JoinInput):
1022
+ self.input = deepcopy(join_input)
1023
+ self.left_manager = JoinInputsManager(self.input.left_select)
1024
+ self.right_manager = JoinInputsManager(self.input.right_select)
1025
+ self.set_join_keys()
1026
+
1027
+ @classmethod
1028
+ def create(cls, join_mapping: Union[List[JoinMap], Tuple[str, str], str],
1029
+ left_select: Union[List[SelectInput], List[str]],
1030
+ right_select: Union[List[SelectInput], List[str]],
1031
+ how: JoinStrategy = 'inner') -> "JoinInputManager":
1032
+ """Factory method to create JoinInput from various input formats."""
1033
+ # Use JoinInput's own create method for parsing
1034
+ join_input = JoinInput(
1035
+ join_mapping=join_mapping,
1036
+ left_select=left_select,
1037
+ right_select=right_select,
1038
+ how=how
1039
+ )
1040
+
1041
+ manager = cls(join_input)
1042
+ manager.set_join_keys()
1043
+ return manager
1044
+
1045
+ def set_join_keys(self) -> None:
1046
+ """Marks the `SelectInput` objects corresponding to join keys."""
1047
+ left_join_keys = self._get_left_join_keys_set()
1048
+ right_join_keys = self._get_right_join_keys_set()
1049
+
1050
+ for select_input in self.input.left_select.renames:
1051
+ select_input.join_key = select_input.old_name in left_join_keys
1052
+
1053
+ for select_input in self.input.right_select.renames:
1054
+ select_input.join_key = select_input.old_name in right_join_keys
1055
+
1056
+ def _get_left_join_keys_set(self) -> Set[str]:
1057
+ """Internal: Returns a set of the left-side join key column names."""
1058
+ return {jm.left_col for jm in self.input.join_mapping}
1059
+
1060
+ def _get_right_join_keys_set(self) -> Set[str]:
1061
+ """Internal: Returns a set of the right-side join key column names."""
1062
+ return {jm.right_col for jm in self.input.join_mapping}
1063
+
1064
+ def get_left_join_keys(self) -> Set[str]:
1065
+ """Returns a set of the left-side join key column names."""
1066
+ return self._get_left_join_keys_set()
1067
+
1068
+ def get_right_join_keys(self) -> Set[str]:
1069
+ """Returns a set of the right-side join key column names."""
1070
+ return self._get_right_join_keys_set()
1071
+
1072
+ def get_left_join_keys_list(self) -> List[str]:
1073
+ """Returns an ordered list of the left-side join key column names."""
1074
+ return [jm.left_col for jm in self.used_join_mapping]
1075
+
1076
+ def get_right_join_keys_list(self) -> List[str]:
1077
+ """Returns an ordered list of the right-side join key column names."""
1078
+ return [jm.right_col for jm in self.used_join_mapping]
1079
+
1080
+ def get_overlapping_records(self) -> Set[str]:
1081
+ """Finds column names that would conflict after the join."""
1082
+ return self.get_overlapping_columns()
1083
+
1084
+ def auto_rename(self) -> None:
1085
+ """Automatically renames columns on the right side to prevent naming conflicts."""
1086
+ self.set_join_keys()
1087
+ overlapping_records = self.get_overlapping_records()
1088
+
1089
+ while len(overlapping_records) > 0:
1090
+ for right_col in self.input.right_select.renames:
1091
+ if right_col.new_name in overlapping_records:
1092
+ right_col.new_name = right_col.new_name + '_right'
1093
+ overlapping_records = self.get_overlapping_records()
1094
+
1095
+ def get_join_key_renames(self, filter_drop: bool = False) -> FullJoinKeyResponse:
1096
+ """Gets the temporary rename mappings for the join keys on both sides."""
1097
+ left_renames = self.left_manager.get_join_key_renames(side="left", filter_drop=filter_drop)
1098
+ right_renames = self.right_manager.get_join_key_renames(side="right", filter_drop=filter_drop)
1099
+ return FullJoinKeyResponse(left_renames, right_renames)
1100
+
1101
+ def get_names_for_table_rename(self) -> List[JoinMap]:
1102
+ """Gets join mapping with renamed columns applied."""
1103
+ new_mappings: List[JoinMap] = []
1104
+ left_rename_table = self.left_manager.get_rename_table()
1105
+ right_rename_table = self.right_manager.get_rename_table()
1106
+
1107
+ for join_map in self.input.join_mapping:
1108
+ new_left = left_rename_table.get(join_map.left_col, join_map.left_col)
1109
+ new_right = right_rename_table.get(join_map.right_col, join_map.right_col)
1110
+ new_mappings.append(JoinMap(left_col=new_left, right_col=new_right))
1111
+
1112
+ return new_mappings
1113
+
1114
+ def get_used_join_mapping(self) -> List[JoinMap]:
1115
+ """Returns the final join mapping after applying all renames and transformations."""
1116
+ new_mappings: List[JoinMap] = []
1117
+ left_rename_table = self.left_manager.get_rename_table()
1118
+ right_rename_table = self.right_manager.get_rename_table()
1119
+ left_join_rename_mapping = self.left_manager.get_join_key_rename_mapping("left")
1120
+ right_join_rename_mapping = self.right_manager.get_join_key_rename_mapping("right")
1121
+ for join_map in self.input.join_mapping:
1122
+ left_col = left_rename_table.get(join_map.left_col, join_map.left_col)
1123
+ right_col = right_rename_table.get(join_map.right_col, join_map.left_col)
1124
+
1125
+ final_left = left_join_rename_mapping.get(left_col, None)
1126
+ final_right = right_join_rename_mapping.get(right_col, None)
1127
+
1128
+ new_mappings.append(JoinMap(left_col=final_left, right_col=final_right))
1129
+
1130
+ return new_mappings
1131
+
1132
+ def to_join_input(self) -> JoinInput:
1133
+ """Creates a new JoinInput instance based on the current manager settings.
1134
+
1135
+ This is useful when you've modified the manager (e.g., via auto_rename) and
1136
+ want to get a fresh JoinInput with all the current settings applied.
1137
+
1138
+ Returns:
1139
+ A new JoinInput instance with current settings
1140
+ """
1141
+ return JoinInput(
1142
+ join_mapping=self.input.join_mapping,
1143
+ left_select=JoinInputs(renames=self.input.left_select.renames.copy()),
1144
+ right_select=JoinInputs(renames=self.input.right_select.renames.copy()),
1145
+ how=self.input.how
1146
+ )
1147
+
1148
+ @property
1149
+ def left_select(self) -> JoinInputsManager:
1150
+ """Backward compatibility: Access left_manager as left_select.
1151
+
1152
+ This returns the MANAGER, not the data model.
1153
+ Usage: manager.left_select.join_key_selects
1154
+ """
1155
+ return self.left_manager
1156
+
1157
+ @property
1158
+ def right_select(self) -> JoinInputsManager:
1159
+ """Backward compatibility: Access right_manager as right_select.
1160
+
1161
+ This returns the MANAGER, not the data model.
1162
+ Usage: manager.right_select.join_key_selects
1163
+ """
1164
+ return self.right_manager
1165
+
1166
+ @property
1167
+ def how(self) -> JoinStrategy:
1168
+ """Backward compatibility: Access join strategy."""
1169
+ return self.input.how
1170
+
1171
+ @property
1172
+ def join_mapping(self) -> List[JoinMap]:
1173
+ """Backward compatibility: Access join mapping."""
1174
+ return self.input.join_mapping
1175
+
1176
+ @property
1177
+ def overlapping_records(self) -> Set[str]:
1178
+ """Backward compatibility: Returns overlapping column names."""
1179
+ return self.get_overlapping_records()
1180
+
1181
+ @property
1182
+ def used_join_mapping(self) -> List[JoinMap]:
1183
+ """Backward compatibility: Returns used join mapping.
1184
+
1185
+ This property is critical - it's used by left_join_keys and right_join_keys.
1186
+ """
1187
+ return self.get_used_join_mapping()
1188
+
1189
+ @property
1190
+ def left_join_keys(self) -> List[str]:
1191
+ """Backward compatibility: Returns left join keys list.
1192
+
1193
+ IMPORTANT: Uses the used_join_mapping PROPERTY (not method).
1194
+ """
1195
+ return [jm.left_col for jm in self.used_join_mapping]
1196
+
1197
+ @property
1198
+ def right_join_keys(self) -> List[str]:
1199
+ """Backward compatibility: Returns right join keys list.
1200
+
1201
+ IMPORTANT: Uses the used_join_mapping PROPERTY (not method).
1202
+ """
1203
+ return [jm.right_col for jm in self.used_join_mapping]
1204
+
1205
+ @property
1206
+ def _left_join_keys(self) -> Set[str]:
1207
+ """Backward compatibility: Private property for left join key set."""
1208
+ return self._get_left_join_keys_set()
1209
+
1210
+ @property
1211
+ def _right_join_keys(self) -> Set[str]:
1212
+ """Backward compatibility: Private property for right join key set."""
1213
+ return self._get_right_join_keys_set()
1214
+
1215
+
1216
+ class FuzzyMatchInputManager(JoinInputManager):
1217
+ """Manager for fuzzy matching join operations."""
1218
+
1219
+ def __init__(self, fuzzy_input: FuzzyMatchInput):
1220
+ self.fuzzy_input = deepcopy(fuzzy_input)
1221
+ super().__init__(JoinInput(
1222
+ join_mapping=[JoinMap(left_col=fm.left_col, right_col=fm.right_col)
1223
+ for fm in self.fuzzy_input.join_mapping],
1224
+ left_select=self.fuzzy_input.left_select,
1225
+ right_select=self.fuzzy_input.right_select,
1226
+ how=self.fuzzy_input.how
1227
+ ))
1228
+
1229
+ @classmethod
1230
+ def create(cls, join_mapping: Union[List[FuzzyMapping], Tuple[str, str], str],
1231
+ left_select: Union[List[SelectInput], List[str]],
1232
+ right_select: Union[List[SelectInput], List[str]],
1233
+ aggregate_output: bool = False,
1234
+ how: JoinStrategy = 'inner') -> "FuzzyMatchInputManager":
1235
+ """Factory method to create FuzzyMatchInput from various input formats."""
1236
+ parsed_mapping = cls.parse_fuzz_mapping(join_mapping)
1237
+ left_inputs = cls.parse_select(left_select)
1238
+ right_inputs = cls.parse_select(right_select)
1239
+
1240
+ fuzzy_input = FuzzyMatchInput(
1241
+ join_mapping=parsed_mapping,
1242
+ left_select=left_inputs,
1243
+ right_select=right_inputs,
1244
+ how=how,
1245
+ aggregate_output=aggregate_output
1246
+ )
1247
+
1248
+ manager = cls(fuzzy_input)
1249
+
1250
+ right_old_names = {v.old_name for v in fuzzy_input.right_select.renames}
1251
+ left_old_names = {v.old_name for v in fuzzy_input.left_select.renames}
1252
+
1253
+ for jm in parsed_mapping:
1254
+ if jm.right_col not in right_old_names:
1255
+ manager.right_manager.append(
1256
+ SelectInput(old_name=jm.right_col, keep=False, join_key=True)
1257
+ )
1258
+ if jm.left_col not in left_old_names:
1259
+ manager.left_manager.append(
1260
+ SelectInput(old_name=jm.left_col, keep=False, join_key=True)
1261
+ )
1262
+
1263
+ manager.set_join_keys()
1264
+ return manager
1265
+
1266
+ @staticmethod
1267
+ def parse_fuzz_mapping(fuzz_mapping: Union[List[FuzzyMapping], Tuple[str, str],
1268
+ str, FuzzyMapping, List[Dict]]) -> List[FuzzyMapping]:
1269
+ """Parses various input formats into a list of FuzzyMapping objects."""
1270
+ if isinstance(fuzz_mapping, (tuple, list)):
1271
+ if len(fuzz_mapping) == 0:
1272
+ raise ValueError("Fuzzy mapping cannot be empty")
1273
+
1274
+ if all(isinstance(fm, dict) for fm in fuzz_mapping):
1275
+ return [FuzzyMapping(**fm) for fm in fuzz_mapping]
1276
+
1277
+ if all(isinstance(fm, FuzzyMapping) for fm in fuzz_mapping):
1278
+ return fuzz_mapping
1279
+
1280
+ if len(fuzz_mapping) <= 2:
1281
+ if len(fuzz_mapping) == 2:
1282
+ if isinstance(fuzz_mapping[0], str) and isinstance(fuzz_mapping[1], str):
1283
+ return [FuzzyMapping(left_col=fuzz_mapping[0], right_col=fuzz_mapping[1])]
1284
+ elif len(fuzz_mapping) == 1 and isinstance(fuzz_mapping[0], str):
1285
+ return [FuzzyMapping(left_col=fuzz_mapping[0], right_col=fuzz_mapping[0])]
1286
+
1287
+ elif isinstance(fuzz_mapping, str):
1288
+ return [FuzzyMapping(left_col=fuzz_mapping, right_col=fuzz_mapping)]
1289
+
1290
+ elif isinstance(fuzz_mapping, FuzzyMapping):
1291
+ return [fuzz_mapping]
1292
+
1293
+ raise ValueError(f'No valid fuzzy mapping as input: {type(fuzz_mapping)}')
1294
+
1295
+ def get_fuzzy_maps(self) -> List[FuzzyMapping]:
1296
+ """Returns the final fuzzy mappings after applying all column renames."""
1297
+ new_mappings = []
1298
+ left_rename_table = self.left_manager.get_rename_table()
1299
+ right_rename_table = self.right_manager.get_rename_table()
1300
+
1301
+ for org_fuzzy_map in self.fuzzy_input.join_mapping:
1302
+ right_col = right_rename_table.get(org_fuzzy_map.right_col, org_fuzzy_map.right_col)
1303
+ left_col = left_rename_table.get(org_fuzzy_map.left_col, org_fuzzy_map.left_col)
1304
+
1305
+ if right_col != org_fuzzy_map.right_col or left_col != org_fuzzy_map.left_col:
1306
+ new_mapping = deepcopy(org_fuzzy_map)
1307
+ new_mapping.left_col = left_col
1308
+ new_mapping.right_col = right_col
1309
+ new_mappings.append(new_mapping)
1310
+ else:
1311
+ new_mappings.append(org_fuzzy_map)
1312
+
1313
+ return new_mappings
1314
+
1315
+ # === Backward Compatibility Properties ===
1316
+
1317
+ @property
1318
+ def fuzzy_maps(self) -> List[FuzzyMapping]:
1319
+ """Backward compatibility: Returns fuzzy mappings."""
1320
+ return self.get_fuzzy_maps()
1321
+
1322
+ @property
1323
+ def join_mapping(self) -> List[FuzzyMapping]:
1324
+ """Backward compatibility: Access fuzzy join mapping."""
1325
+ return self.get_fuzzy_maps()
1326
+
1327
+ @property
1328
+ def aggregate_output(self) -> bool:
1329
+ """Backward compatibility: Access aggregate_output setting."""
1330
+ return self.fuzzy_input.aggregate_output
1331
+
1332
+ def to_fuzzy_match_input(self) -> FuzzyMatchInput:
1333
+ """Creates a new FuzzyMatchInput instance based on the current manager settings.
1334
+
1335
+ This is useful when you've modified the manager (e.g., via auto_rename) and
1336
+ want to get a fresh FuzzyMatchInput with all the current settings applied.
1337
+
1338
+ Returns:
1339
+ A new FuzzyMatchInput instance with current settings
1340
+ """
1341
+ return FuzzyMatchInput(
1342
+ join_mapping=self.fuzzy_input.join_mapping,
1343
+ left_select=JoinInputs(renames=self.input.left_select.renames.copy()),
1344
+ right_select=JoinInputs(renames=self.input.right_select.renames.copy()),
1345
+ how=self.fuzzy_input.how,
1346
+ aggregate_output=self.fuzzy_input.aggregate_output
1347
+ )