modaic-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modaic might be problematic.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/context/dtype_mapping.py
ADDED

@@ -0,0 +1,25 @@
import numpy as np

INTEGER_DTYPE_MAPPING = {
    np.int8: "TINYINT",
    np.int16: "SMALLINT",
    np.int32: "INT",
    np.int64: "BIGINT",
}

SPECIAL_INTEGER_DTYPE_MAPPING = {"Int64": "BIGINT", "UInt64": "BIGINT UNSIGNED"}

FLOAT_DTYPE_MAPPING = {
    np.float16: "FLOAT",
    np.float32: "FLOAT",
    np.float64: "DOUBLE",
}

OTHER_DTYPE_MAPPING = {
    "boolean": "BOOLEAN",
    "datetime": "DATETIME",
    "timedelta": "TIME",
    "string": "VARCHAR(255)",
    "category": "VARCHAR(255)",
    "default": "TEXT",
}
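
These tables drive the pandas-to-MySQL type translation performed by `_pandas_to_mysql_dtype` in `modaic/context/table.py` below. A minimal sketch of how they can be consulted directly (illustrative only; the real lookup also handles bool, datetime, string, and categorical dtypes):

```python
import numpy as np
import pandas as pd

from modaic.context.dtype_mapping import FLOAT_DTYPE_MAPPING, INTEGER_DTYPE_MAPPING

# Map each column's numpy scalar type to a MySQL column type.
df = pd.DataFrame({"id": pd.Series([1, 2], dtype=np.int32), "score": [0.5, 0.7]})
for col in df.columns:
    np_type = df[col].dtype.type  # e.g. np.int32, np.float64
    sql_type = INTEGER_DTYPE_MAPPING.get(np_type) or FLOAT_DTYPE_MAPPING.get(np_type, "TEXT")
    print(f"{col}: {sql_type}")  # id: INT, score: DOUBLE
```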
modaic/context/table.py
ADDED
@@ -0,0 +1,585 @@
import hashlib
import random
import re
import warnings
from abc import ABC
from contextlib import contextmanager
from pathlib import Path
from typing import IO, Any, Dict, List, Literal, Optional, Set

import duckdb
import numpy as np
import pandas as pd
from pydantic import PrivateAttr, ValidatorFunctionWrapHandler, field_validator, model_validator

from modaic.context.base import Context, HydratedAttr, requires_hydration
from modaic.types import Field

from ..storage.file_store import FileStore
from .dtype_mapping import (
    FLOAT_DTYPE_MAPPING,
    INTEGER_DTYPE_MAPPING,
    OTHER_DTYPE_MAPPING,
    SPECIAL_INTEGER_DTYPE_MAPPING,
)
from .text import Text


class BaseTable(Context, ABC):
    name: str
    _df: pd.DataFrame = PrivateAttr()

    @field_validator("name", mode="before")
    @classmethod
    def sanitize_name(cls, name: str) -> str:
        # Delegates to the module-level sanitize_name helper defined below
        return sanitize_name(name)

    def column_samples(self, col: str) -> list[Any]:
        """
        Return up to 3 distinct sample values from the given column.

        Picks at most three unique, non-null, short (<64 chars) values from
        the column, favoring speed by sampling after de-duplicating values.

        Args:
            col: Column name to sample from.

        Returns:
            A list with up to three JSON-serializable sample values.
        """
        # TODO: validate that `col` exists before indexing
        series = self._df[col]

        valid_values = [x for x in series.dropna().unique() if pd.notnull(x) and len(str(x)) < 64]
        sample_values = random.sample(valid_values, min(3, len(valid_values)))

        # Convert numpy scalars to Python native types for JSON serialization
        converted_values = []
        for val in sample_values:
            if hasattr(val, "item"):  # numpy scalars expose .item()
                converted_values.append(val.item())
            else:
                converted_values.append(val)

        return converted_values

    def get_col(self, col_name: str) -> pd.Series:
        """
        Gets a single column from the table.

        Args:
            col_name: Name of the column to get.

        Returns:
            The specified column as a pandas Series.
        """
        return self._df[col_name]

    def schema_info(self) -> dict[str, Any]:
        column_dict = {}
        for col in self._df.columns:
            if isinstance(self._df[col], pd.DataFrame):
                raise ValueError(f"Column {col} is a DataFrame, which is not supported.")
            column_dict[col] = {
                "type": _pandas_to_mysql_dtype(self._df[col].dtype),
                "sample_values": self.column_samples(col),
            }

        schema_dict = {"table_name": self.name, "column_dict": column_dict}

        return schema_dict

    def query(self, query: str) -> pd.DataFrame:
        """
        Queries the table using DuckDB SQL.

        Notes:
            - Refer to the in-memory table by its (sanitized) `name`, which is
              the name it is registered under in DuckDB.

        Example:
            ```python
            # Select a few rows
            df = table.query(f"SELECT * FROM {table.name} LIMIT 5")

            # Aggregate over a column
            df = table.query(f"SELECT category, COUNT(*) AS n FROM {table.name} GROUP BY category")
            ```
        """
        return duckdb.query_df(self._df, self.name, query).to_df()

    def embedme(self) -> str:
        """
        embedme method for table. Returns a markdown representation of the table.
        """
        return self.markdown()

    def markdown(self) -> str:  # TODO: add example
        """
        Converts the table to markdown format.
        Returns a markdown representation of the table with the table name as header.
        """
        content = ""
        content += f"Table name: {self.name}\n"

        # Add header row
        columns = [str(col) for col in self._df.columns]
        content += "| " + " | ".join(columns) + " |\n"

        # Add header separator
        content += "| " + " | ".join(["---"] * len(columns)) + " |\n"

        # Add data rows
        for _, row in self._df.iterrows():
            row_values = []
            for value in row:
                if pd.isna(value) or value is None:
                    row_values.append("")
                else:
                    row_values.append(str(value))
            content += "| " + " | ".join(row_values) + " |\n"

        return content

    def to_text(self) -> Text:
        """
        Converts the table to markdown and returns a Text context object.
        """
        return Text(text=self.markdown())


class Table(BaseTable):
    content: str = ""
    df: List[Dict[str, Any]] = Field(hidden=True)
    _df: pd.DataFrame = PrivateAttr()

    @model_validator(mode="wrap")
    @classmethod
    def truncate(cls, data: Any, handler: ValidatorFunctionWrapHandler) -> "Table":
        # Accept either a DataFrame or a list of records; store the
        # serializable form on `df` and cache the DataFrame on `_df`.
        df = data["df"]
        if isinstance(df, pd.DataFrame):
            serialized_df = df.to_dict(orient="records")
            pd_df = df
        else:
            serialized_df = df
            pd_df = pd.DataFrame(serialized_df)
        data["df"] = serialized_df
        self = handler(data)
        self._df = pd_df
        return self


class TableFile(BaseTable):
    """
    A Context object to represent table documents such as excel, csv and tsv files.
    """

    file_ref: str
    file_type: Literal["xls", "xlsx", "csv", "tsv"]
    sheet_name: Optional[str] = None
    _df: pd.DataFrame = HydratedAttr()

    @classmethod
    def from_file(
        cls,
        file_ref: str,
        file: Path | IO,
        file_type: Literal["xls", "xlsx", "csv", "tsv"] = "xls",
        name: Optional[str] = None,
        sheet_name: Optional[str] = None,
        **kwargs,
    ) -> "TableFile":
        # NOTE: Always set a sheet name for excel files
        if file_type in ["xls", "xlsx"] and sheet_name is None:
            xls = pd.ExcelFile(file)
            sheet_name = xls.sheet_names[0]
        # NOTE: ensure we always have a semantic name for the table
        if name is None and file_type in ["xls", "xlsx"]:
            name = sheet_name
        elif name is None:
            name = file_ref.split("/")[-1].split(".")[0]
        instance = cls(
            name=name,
            file_ref=file_ref,
            file_type=file_type,
            sheet_name=sheet_name,
            **kwargs,
        )
        instance._hydrate_from_file(file)
        return instance

    @classmethod
    def from_file_store(cls, file_ref: str, file_store: FileStore, **kwargs) -> "TableFile":
        file_result = file_store.get(file_ref)
        if "sheet_name" in file_result.metadata:
            sheet_name = file_result.metadata["sheet_name"]
        else:
            sheet_name = None
        return cls.from_file(
            file_ref,
            file_result.file,
            file_result.type,
            name=file_result.name,
            sheet_name=sheet_name,
            **kwargs,
        )

    def hydrate(self, file_store: FileStore):
        # FileStore.get returns a result wrapper; the underlying file is on `.file`
        file_result = file_store.get(self.file_ref)
        self._hydrate_from_file(file_result.file)

    def _hydrate_from_file(self, file: Path | IO):
        if self.file_type in ["xls", "xlsx"]:
            if self.sheet_name is None:
                df = pd.read_excel(file)
            else:
                df = pd.read_excel(file, sheet_name=self.sheet_name)
        elif self.file_type == "csv":
            df = pd.read_csv(file)
        elif self.file_type == "tsv":
            df = pd.read_csv(file, sep="\t")
        else:
            raise ValueError(f"TableFile got unsupported file type: {self.file_type}")
        self._df = _process_df(df)

    @requires_hydration
    def column_samples(self, col: str) -> list[Any]:
        """
        Returns up to 3 distinct sample values from the given column.
        """
        return super().column_samples(col)

    @requires_hydration
    def get_col(self, col_name: str) -> pd.Series:
        """
        Gets a single column from the table.
        """
        return super().get_col(col_name)

    @requires_hydration
    def schema_info(self) -> dict[str, Any]:
        """
        Returns the schema information of the table.
        """
        return super().schema_info()

    @requires_hydration
    def query(self, query: str) -> pd.DataFrame:
        """
        Queries the table using DuckDB SQL.
        """
        return super().query(query)

    @requires_hydration
    def embedme(self) -> str:
        """
        Returns a markdown representation of the table.
        """
        return super().embedme()

    @requires_hydration
    def markdown(self) -> str:
        """
        Converts the table to markdown format.
        """
        return super().markdown()

    @requires_hydration
    def to_text(self) -> Text:
        """
        Converts the table to markdown and returns a Text context object.
        """
        return super().to_text()


class BaseTabbedTable(Context):
    names: Set[str]
    _tables: Optional[Dict[str, pd.DataFrame]] = PrivateAttr()
    _sql_db: Optional[duckdb.DuckDBPyConnection] = PrivateAttr()

    def init_sql(self):
        """
        Initializes an in-memory SQL database for querying.
        """
        self._sql_db = duckdb.connect(database=":memory:")
        for table_name, table in self._tables.items():
            # Each value is already a pd.DataFrame, so register it directly
            self._sql_db.register(table_name, table)

    def close_sql(self):
        """
        Closes the in-memory SQL database.
        """
        self._sql_db.close()
        self._sql_db = None

    @contextmanager
    def sql(self):
        self.init_sql()
        try:
            yield self._sql_db
        finally:
            self.close_sql()

    def query(self, query: str) -> Table:
        """
        Queries the in-memory SQL database.
        """
        if self._sql_db is None:
            raise ValueError(
                "Attempted to run query on a TabbedTable without initializing the SQL database. "
                "Use `with table.sql():` or call `table.init_sql()` first."
            )
        try:
            df = self._sql_db.execute(query).fetchdf()
            return Table(df=df, name="query_result")
        except Exception as e:
            raise ValueError("Error querying SQL database") from e


class TabbedTable(BaseTabbedTable):
    tables: Dict[str, List[Dict[str, Any]]]
    _tables: Dict[str, pd.DataFrame] = PrivateAttr()
    _sql_db: duckdb.DuckDBPyConnection = PrivateAttr()

    @field_validator("tables", mode="before")
    @classmethod
    def serialize_tables(cls, tables: Dict[str, pd.DataFrame] | Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        first_val = next(iter(tables.values()))
        if isinstance(first_val, pd.DataFrame):
            serialized_tables = {k: v.to_dict(orient="records") for k, v in tables.items()}
        else:
            serialized_tables = tables
        return serialized_tables

    @model_validator(mode="after")
    def set_tables(self) -> "TabbedTable":
        self._tables = {k: pd.DataFrame(v) for k, v in self.tables.items()}
        return self


class TabbedTableFile(BaseTabbedTable):
    file_ref: str
    file_type: Literal["excel"] = "excel"
    _tables: Dict[str, pd.DataFrame] = HydratedAttr()
    _sql_db: duckdb.DuckDBPyConnection = HydratedAttr()

    @classmethod
    def from_file(
        cls,
        file_ref: str,
        file: Path | IO,
        file_type: Literal["excel"] = "excel",
        names: Optional[List[str]] = None,
        **kwargs,
    ) -> "TabbedTableFile":
        if file_type == "excel":
            xls = pd.ExcelFile(file)
            if names is None:
                names = xls.sheet_names
            else:
                for name in names:
                    if name not in xls.sheet_names:
                        raise ValueError(f"Sheet name {name} not found in file")
        elif names is None:
            raise ValueError(f"names must be provided for file type: {file_type}")

        instance = cls(
            file_ref=file_ref,
            file_type=file_type,
            names=set(names),
            **kwargs,
        )
        instance._hydrate_from_file(file)
        return instance

    @classmethod
    def from_file_store(cls, file_ref: str, file_store: FileStore, **kwargs) -> "TabbedTableFile":
        file_result = file_store.get(file_ref)
        # NOTE: from_file takes `names`, not `sheet_name`; a single sheet name
        # stored in the metadata narrows the tabs to load
        if "sheet_name" in file_result.metadata:
            names = [file_result.metadata["sheet_name"]]
        else:
            names = None
        return cls.from_file(
            file_ref,
            file_result.file,
            file_result.type,
            names=names,
            **kwargs,
        )

    def hydrate(self, file_store: FileStore):
        file_result = file_store.get(self.file_ref)
        self._hydrate_from_file(file_result.file)

    def _hydrate_from_file(self, file: Path | IO):
        # pd.read_excel accepts both paths and file-like objects directly,
        # so no pre-reading is needed; sheet_name expects a list, not a set
        if self.file_type == "excel":
            df_dict = pd.read_excel(file, sheet_name=list(self.names))
        else:
            raise ValueError(f"Unsupported file type: {self.file_type}")

        if isinstance(df_dict, dict):
            self._tables = {name: _process_df(df) for name, df in df_dict.items()}
        else:
            self._tables = {next(iter(self.names)): _process_df(df_dict)}

    @classmethod
    def from_gsheet(cls):
        raise NotImplementedError("Not implemented")

    @classmethod
    def from_sharepoint(cls):
        raise NotImplementedError("Not implemented")

    @classmethod
    def from_s3(cls):
        raise NotImplementedError("Not implemented")


def downcast_pd_series(series: pd.Series) -> pd.Series:
    try:
        return pd.to_numeric(series, downcast="integer")
    except ValueError:
        pass
    try:
        return pd.to_numeric(series, downcast="float")
    except ValueError:
        pass
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            return pd.to_datetime(series)
    except ValueError:
        pass
    return series


def _pandas_to_mysql_dtype(dtype: np.dtype) -> str:
    if pd.api.types.is_integer_dtype(dtype):
        if str(dtype) in SPECIAL_INTEGER_DTYPE_MAPPING:
            return SPECIAL_INTEGER_DTYPE_MAPPING[str(dtype)]
        return INTEGER_DTYPE_MAPPING.get(dtype, "INT")

    elif pd.api.types.is_float_dtype(dtype):
        return FLOAT_DTYPE_MAPPING.get(dtype, "FLOAT")

    elif pd.api.types.is_bool_dtype(dtype):
        return OTHER_DTYPE_MAPPING["boolean"]

    elif pd.api.types.is_datetime64_any_dtype(dtype):
        return OTHER_DTYPE_MAPPING["datetime"]

    elif pd.api.types.is_timedelta64_dtype(dtype):
        return OTHER_DTYPE_MAPPING["timedelta"]

    elif pd.api.types.is_string_dtype(dtype):
        return OTHER_DTYPE_MAPPING["string"]

    elif pd.api.types.is_categorical_dtype(dtype):
        return OTHER_DTYPE_MAPPING["category"]

    else:
        return OTHER_DTYPE_MAPPING["default"]


def sanitize_name(original_name: str) -> str:  # TODO: also sanitize SQL keywords
    """
    Sanitizes names of files and directories.

    Rules:
        1. Remove the file extension
        2. Replace illegal characters (including `-`) with underscores
        3. Collapse consecutive underscores into a single underscore
        4. Convert to lowercase
        5. Remove leading/trailing underscores
        6. If the name starts with a number, prefix it with `t_`
        7. If the name is longer than 64 chars, truncate and add a hash suffix

    Args:
        original_name: The name to sanitize.

    Returns:
        The sanitized name.
    """
    # Remove file extension
    name = original_name.split(".")[0]

    # Replace illegal characters with underscores
    name = re.sub(r"[^a-zA-Z0-9_]", "_", name)

    # Remove consecutive underscores
    name = re.sub(r"_+", "_", name)

    # Remove leading/trailing underscores
    if len(name) > 2:
        name = name.strip("_")

    # Convert to lowercase
    name = name.lower()

    # Ensure name does not start with a number
    if name and name[0].isdigit():
        name = "t_" + name

    # If name is longer than 64 chars, truncate and add a hash suffix
    if len(name) > 64:
        prefix = name[:20].rstrip("_")
        hash_suffix = hashlib.md5(name.encode("utf-8")).hexdigest()[:8]
        name = f"{prefix}_{hash_suffix}"

    return name


def is_valid_table_name(name: str) -> bool:
    """
    Checks if a name is a valid table name.

    Args:
        name: The name to validate.

    Returns:
        True if the name is valid, False otherwise.
    """
    valid = (
        name.islower()
        and not name.startswith("_")
        and not name.endswith("_")
        and not name[0].isdigit()
        and len(name) <= 64
    )
    return valid


def downcast_column(col: pd.Series) -> pd.Series:
    """
    Downcasts a column to the smallest possible dtype.
    """
    # Operate on the Series as a whole; Series.apply would pass scalars instead
    return downcast_pd_series(col)


def _sanitize_columns(df: pd.DataFrame) -> None:
    columns = [sanitize_name(str(col)) for col in df.columns]
    columns = ["No" if i == 0 and (not col or pd.isna(col)) else col for i, col in enumerate(columns)]

    # Deduplicate column names by suffixing repeats with a counter
    seen = {}
    new_columns = []
    for col in columns:
        if col in seen:
            seen[col] += 1
            new_columns.append(f"{col}_{seen[col]}")
        else:
            seen[col] = 0
            new_columns.append(col)
    df.columns = new_columns


def _process_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Processes the dataframe to ensure it is in the correct format.
    """
    # Downcast columns, then sanitize and deduplicate column names
    df = df.apply(downcast_pd_series)
    _sanitize_columns(df)
    return df
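
Taken together, `Table` wraps an in-memory DataFrame while `TableFile` defers loading until hydration from a `FileStore`. A minimal usage sketch (illustrative only, based on the classes above; the sample data and query are invented):

```python
import pandas as pd

from modaic.context.table import Table

# Build a Table from a DataFrame; the name validator sanitizes it
# ("Sales Data.xlsx" -> "sales_data").
df = pd.DataFrame({"category": ["a", "b", "a"], "price": [1.5, 2.0, 3.25]})
table = Table(df=df, name="Sales Data.xlsx")

# Query with DuckDB SQL, referring to the table by its sanitized name.
result = table.query(f"SELECT category, COUNT(*) AS n FROM {table.name} GROUP BY category")
print(result)

# Render the rows as a markdown table (also what embedme() returns).
print(table.markdown())
```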
modaic/context/text.py
ADDED
@@ -0,0 +1,94 @@
from pathlib import Path
from typing import IO, Callable, Iterable, Iterator, List, Literal, Optional

from modaic.storage.file_store import FileStore

from .base import Context, HydratedAttr, requires_hydration


class Text(Context):
    """
    Text context class.
    """

    text: str

    def chunk_text(
        self,
        chunk_fn: Callable[[str], Iterable[str | tuple[str, dict]]],
        kwargs: Optional[dict] = None,
    ):
        def chunk_text_fn(text_context: "Text") -> Iterator["Text"]:
            for chunk in chunk_fn(text_context.text, **(kwargs or {})):
                yield Text(text=chunk)

        self.chunk_with(chunk_text_fn)

    @classmethod
    def from_file(cls, file: str | Path | IO, type: Literal["txt"] = "txt", params: Optional[dict] = None) -> "Text":
        """
        Load a Text instance from a file.
        """
        if isinstance(file, (str, Path)):
            file = Path(file)
            text = file.read_text()
        else:
            # File-like objects (typing.IO does not support isinstance checks)
            text = file.read()
        return cls(text=text, **(params or {}))


class TextFile(Context):
    """
    Text document context class.
    """

    _text: str = HydratedAttr()
    file_ref: str
    file_type: Literal["txt"] = "txt"

    def hydrate(self, file_store: FileStore) -> None:
        # FileStore.get returns a result wrapper; the underlying file is on `.file`
        file = file_store.get(self.file_ref).file
        if isinstance(file, Path):
            file = file.read_text()
        else:
            file = file.read()
        self._text = file

    @classmethod
    def from_file_store(
        cls,
        file_ref: str,
        file_store: FileStore,
        params: Optional[dict] = None,
    ) -> "TextFile":
        """
        Load a TextFile instance from a file store.

        Args:
            file_ref: Reference to the file in the file store.
            file_store: The file store to load from.
            params: Extra keyword arguments passed to the constructor.
        """
        instance = cls(file_ref=file_ref, **(params or {}))
        instance.hydrate(file_store)
        return instance

    @requires_hydration
    def dump(self) -> str:
        return self._text

    @requires_hydration
    def chunk_text(
        self,
        chunk_fn: Callable[[str], List[str | tuple[str, dict]]],
        kwargs: Optional[dict] = None,
    ):
        def chunk_text_fn(text_context: "TextFile") -> List["Text"]:
            chunks = []
            for chunk in chunk_fn(text_context._text, **(kwargs or {})):
                chunks.append(Text(text=chunk))
            return chunks

        self.apply_to_chunks(chunk_text_fn)
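
`chunk_text` adapts any plain `str -> chunks` splitter to the Context chunking pipeline. A sketch of what a caller might look like (the splitter below is hypothetical, not part of modaic, and this assumes `chunk_with` is provided by the `Context` base class, as the code above implies):

```python
from modaic.context.text import Text

def split_paragraphs(text: str) -> list[str]:
    # Hypothetical splitter: break on blank lines and drop empty chunks.
    return [p.strip() for p in text.split("\n\n") if p.strip()]

doc = Text(text="First paragraph.\n\nSecond paragraph.")
# Registers a chunker that wraps each paragraph in its own Text context.
doc.chunk_text(split_paragraphs)
```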
|