tracebloc-ingestor 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. tracebloc_ingestor-0.1.0/LICENSE +21 -0
  2. tracebloc_ingestor-0.1.0/PKG-INFO +48 -0
  3. tracebloc_ingestor-0.1.0/setup.cfg +4 -0
  4. tracebloc_ingestor-0.1.0/setup.py +26 -0
  5. tracebloc_ingestor-0.1.0/tracebloc_ingestor/__init__.py +24 -0
  6. tracebloc_ingestor-0.1.0/tracebloc_ingestor/api/__init__.py +0 -0
  7. tracebloc_ingestor-0.1.0/tracebloc_ingestor/api/client.py +267 -0
  8. tracebloc_ingestor-0.1.0/tracebloc_ingestor/config.py +41 -0
  9. tracebloc_ingestor-0.1.0/tracebloc_ingestor/database.py +247 -0
  10. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/__init__.py +12 -0
  11. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/blob_ingestion.py +168 -0
  12. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/csv_ingestion.py +115 -0
  13. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/custom_processor.py +115 -0
  14. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/image_ingestion.py +225 -0
  15. tracebloc_ingestor-0.1.0/tracebloc_ingestor/examples/json_ingestion.py +120 -0
  16. tracebloc_ingestor-0.1.0/tracebloc_ingestor/ingestors/__init__.py +18 -0
  17. tracebloc_ingestor-0.1.0/tracebloc_ingestor/ingestors/base.py +389 -0
  18. tracebloc_ingestor-0.1.0/tracebloc_ingestor/ingestors/csv_ingestor.py +225 -0
  19. tracebloc_ingestor-0.1.0/tracebloc_ingestor/ingestors/json_ingestor.py +231 -0
  20. tracebloc_ingestor-0.1.0/tracebloc_ingestor/processors/__init__.py +12 -0
  21. tracebloc_ingestor-0.1.0/tracebloc_ingestor/processors/base.py +27 -0
  22. tracebloc_ingestor-0.1.0/tracebloc_ingestor/utils/__init__.py +0 -0
  23. tracebloc_ingestor-0.1.0/tracebloc_ingestor/utils/constants.py +62 -0
  24. tracebloc_ingestor-0.1.0/tracebloc_ingestor/utils/logging.py +55 -0
  25. tracebloc_ingestor-0.1.0/tracebloc_ingestor.egg-info/PKG-INFO +48 -0
  26. tracebloc_ingestor-0.1.0/tracebloc_ingestor.egg-info/SOURCES.txt +27 -0
  27. tracebloc_ingestor-0.1.0/tracebloc_ingestor.egg-info/dependency_links.txt +1 -0
  28. tracebloc_ingestor-0.1.0/tracebloc_ingestor.egg-info/requires.txt +18 -0
  29. tracebloc_ingestor-0.1.0/tracebloc_ingestor.egg-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 tracebloc
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,48 @@
1
+ Metadata-Version: 2.1
2
+ Name: tracebloc_ingestor
3
+ Version: 0.1.0
4
+ Summary: A flexible data ingestion library for various file formats
5
+ Home-page: https://github.com/tracebloc/data-ingestors
6
+ Author: Tracebloc
7
+ Author-email: support@tracebloc.com
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+
15
+ # Data Ingestors
16
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
17
+
18
+ ## 📄 Description
19
+ A robust data ingestion framework for machine learning pipelines. This repository provides tools and utilities for managing, processing, and validating training/test datasets. It supports various data sources, formats, and processing pipelines, making it easier to create and maintain ML datasets.
20
+
21
+ ## 🛠️ Tech Stack
22
+ - Python 3.x
23
+ - Docker (for containerization)
24
+ - Data processing libraries (Pandas, NumPy)
25
+
26
+ ## 🚀 Installation & Usage Instructions
27
+ 1. Clone the repository
28
+ 2. Install dependencies:
29
+ ```bash
30
+ pip install -r src/requirements.txt
31
+ ```
32
+ 3. Configure your environment
33
+ 4. Follow the documentation guide to [Create Your Training/Test Dataset](https://traceblocdocsdev.azureedge.net/environment-setup/create-your-dataset)
34
+
35
+ ## 📦 Features
36
+ - Multi-source data ingestion
37
+ - Data validation and preprocessing
38
+ - Database integration
39
+ - API endpoints for data management
40
+ - Containerized deployment
41
+ - Kubernetes support
42
+
43
+
44
+ ## 📜 License
45
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
46
+
47
+ ## 📞 Support
48
+ For additional support or questions, please refer to our documentation or contact the Tracebloc support team at `support@tracebloc.io`.
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,26 @@
1
+ from setuptools import setup, find_packages
2
+
3
from pathlib import Path

# Resolve packaging inputs relative to this file rather than the current
# working directory, so the build also works when invoked from elsewhere.
_HERE = Path(__file__).resolve().parent


def _read_text(relative_path, default=None):
    """Return the UTF-8 text of *relative_path* (under setup.py's directory), or *default* if it is absent."""
    path = _HERE / relative_path
    if not path.is_file():
        return default
    return path.read_text(encoding="utf-8")


def _read_requirements():
    """Return the install requirements as a list of requirement strings.

    Reads requirements.txt, skipping blank lines and ``#`` comments (neither
    is a valid ``install_requires`` entry). The 0.1.0 sdist ships neither
    requirements.txt nor README.md, so when requirements.txt is missing this
    falls back to the ``requires.txt`` in the bundled egg-info metadata,
    stopping at its first ``[extra]`` section header.
    """
    content = _read_text("requirements.txt")
    from_egg_info = content is None
    if from_egg_info:
        content = _read_text("tracebloc_ingestor.egg-info/requires.txt", default="")

    requirements = []
    for raw_line in content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if from_egg_info and line.startswith("["):
            # Extras sections follow the core requirements in requires.txt.
            break
        requirements.append(line)
    return requirements


# README.md is optional so that building from the sdist does not crash.
long_description = _read_text("README.md", default="")
requirements = _read_requirements()

setup(
    name="tracebloc_ingestor",
    version="0.1.0",
    author="Tracebloc",
    author_email="support@tracebloc.com",
    description="A flexible data ingestion library for various file formats",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/tracebloc/data-ingestors",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.8",
    install_requires=requirements,
)
@@ -0,0 +1,24 @@
1
+ """Tracebloc Data Ingestor Package.
2
+
3
+ A flexible and extensible framework for ingesting data from various sources into a database
4
+ and optionally sending it to an API. The package provides base classes for creating custom
5
+ ingestors and processors, along with built-in support for common data formats.
6
+ """
7
+
8
+ from .config import Config
9
+ from .database import Database
10
+ from .api.client import APIClient
11
+ from .ingestors import BaseIngestor, CSVIngestor, JSONIngestor
12
+ from .processors.base import BaseProcessor
13
+
14
__version__ = "0.1.0"

# Names re-exported at package level for `from tracebloc_ingestor import *`.
__all__ = [
    "Config",
    "Database",
    "APIClient",
    "BaseIngestor",
    "CSVIngestor",
    "JSONIngestor",
    "BaseProcessor",
]
@@ -0,0 +1,267 @@
1
+ from typing import List, Tuple, Dict, Any
2
+ import requests, json
3
+ import logging
4
+ from requests.adapters import HTTPAdapter
5
+ from requests.packages.urllib3.util.retry import Retry
6
+ from ..config import Config
7
+ from ..utils.logging import setup_logging
8
+ from ..utils.constants import DataCategory, API_TIMEOUT
9
+
10
+ # Configure unified logging with config
11
# Import-time setup: build a Config from the environment, hand it to the
# shared logging configuration, then obtain this module's logger.
config: Config = Config()
setup_logging(config)
logger = logging.getLogger(__name__)
14
+
15
class APIClient:
    """HTTP client for the tracebloc backend API.

    On construction it opens a ``requests.Session`` that retries on
    500/502/503/504 responses and authenticates immediately. The returned
    token is stored and sent with every later request.
    """

    def __init__(self, config: Config):
        """Create the HTTP session and authenticate.

        Args:
            config: Provides ``API_ENDPOINT``, ``CLIENT_USERNAME``,
                ``CLIENT_PASSWORD``, ``COMPANY`` and ``TITLE``.

        Raises:
            requests.exceptions.RequestException: If authentication fails.
        """
        self.config = config
        self.session = self._create_session()
        self.token = self.authenticate()

    def _create_session(self) -> requests.Session:
        """Return a session whose HTTP(S) adapters retry on 5xx server errors."""
        session = requests.Session()

        # Up to 3 retries with exponential backoff on transient server errors.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[500, 502, 503, 504],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def _headers(self, json_body: bool = False) -> Dict[str, str]:
        """Build request headers carrying the auth token.

        Args:
            json_body: Also set ``Content-Type: application/json``.
        """
        headers = {"Authorization": f"TOKEN {self.token}"}
        if json_body:
            headers["Content-Type"] = "application/json"
        return headers

    @staticmethod
    def _body_for_log(response: requests.Response) -> Any:
        """Return the decoded JSON body for logging, or the raw text if it is not JSON.

        Without this, a successful request whose body is empty or not JSON
        would be reported as a failure just because logging tried to decode it.
        """
        try:
            return response.json()
        except ValueError:
            return response.text

    @staticmethod
    def _log_request_error(message: str, error: requests.exceptions.RequestException) -> None:
        """Log *message* with the exception and, when available, the server's response body."""
        logger.error(f"{message}: {str(error)}")
        response = getattr(error, "response", None)
        if hasattr(response, "text"):
            logger.error(f"Error response: {response.text}")

    def authenticate(self) -> str:
        """Obtain an API token using the configured client credentials.

        Returns:
            The ``"token"`` field of the response body (``None`` if absent).

        Raises:
            requests.exceptions.RequestException: On network/HTTP errors.
        """
        try:
            response = self.session.post(
                f"{self.config.API_ENDPOINT}/api-token-auth/",
                json={"username": self.config.CLIENT_USERNAME, "password": self.config.CLIENT_PASSWORD},
                timeout=API_TIMEOUT,
            )
            response.raise_for_status()
            token = response.json().get("token")
        except requests.exceptions.RequestException as e:
            logger.error(f"Error during authentication: {str(e)}")
            raise

        # The response body *is* the credential, so it must never be logged.
        if token:
            logger.info("Authentication succeeded")
        else:
            logger.warning("Authentication response did not contain a token")
        return token

    def send_batch(self, records: List[Tuple[int, Dict[str, Any]]], table_name: str, ingestor_id: str) -> bool:
        """Send metadata for a batch of records to ``/global_meta/<table_name>/``.

        Only ``data_id``, ``data_intent`` (default ``"train"``) and ``label``
        (default ``""``) are taken from each record. Company, sample/active
        flags and the ingestor id are filled in by the client.

        Args:
            records: ``(id, record)`` pairs; the id element is not sent.
            table_name: Name of the table to send data to.
            ingestor_id: Unique ID of the ingestor.

        Returns:
            True if the API accepted the batch, False otherwise.
        """
        try:
            payload = json.dumps([
                {
                    "data_id": record_data.get("data_id"),
                    "company": self.config.COMPANY,
                    "data_intent": record_data.get("data_intent", "train"),
                    "label": record_data.get("label", ""),
                    "is_sample": False,
                    "is_active": True,
                    # Spelled "injestor_id" (sic) throughout this client; presumably
                    # the backend's field name, so do not "fix" it here alone.
                    "injestor_id": ingestor_id,
                }
                for _, record_data in records
            ])
            logger.info(f"Data to send: {payload}")

            response = self.session.post(
                f"{self.config.API_ENDPOINT}/global_meta/{table_name}/",
                data=payload,
                headers=self._headers(json_body=True),
                timeout=API_TIMEOUT,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self._log_request_error("Error sending batch to API", e)
            return False

        logger.info(f"Successfully sent batch. Response: {self._body_for_log(response)}")
        return True

    def send_global_meta_meta(self, table_name: str, schema: Dict[str, str]) -> bool:
        """Send a table's schema to ``/global_meta/global_metadata/``.

        Args:
            table_name: The dataset/table name.
            schema: Mapping of column name to type string.

        Returns:
            True if successful, False otherwise.
        """
        try:
            payload = json.dumps({
                "table_name": table_name,
                "schema": schema,
            })
            logger.info(f"Global metadata to send: {payload}")

            response = self.session.post(
                f"{self.config.API_ENDPOINT}/global_meta/global_metadata/",
                data=payload,
                headers=self._headers(json_body=True),
                timeout=API_TIMEOUT,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self._log_request_error("Error sending global metadata to API", e)
            return False

        logger.info(f"Successfully sent global metadata. Response: {self._body_for_log(response)}")
        return True

    def send_generate_edge_label_meta(self, table_name: str, ingestor_id: str) -> bool:
        """Ask the backend to generate edge label metadata for a table.

        Args:
            table_name: The dataset/table name.
            ingestor_id: Unique ID of the ingestor.

        Returns:
            True if successful, False otherwise.
        """
        try:
            url = f"{self.config.API_ENDPOINT}/global_meta/generate-edge-labels-meta/"
            # Passing params lets requests URL-encode the values.
            params = {"table_name": table_name, "injestor_id": ingestor_id}

            logger.info(f"Sending request to generate edge label metadata for dataset type: {table_name}")
            response = self.session.get(url, headers=self._headers(), params=params, timeout=API_TIMEOUT)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self._log_request_error("Error generating edge label metadata", e)
            return False

        logger.info(f"Successfully generated edge label metadata. Response: {self._body_for_log(response)}")
        return True

    def prepare_dataset(self, category: str, ingestor_id: str) -> bool:
        """Ask the backend to prepare data for a category and ingestor.

        Args:
            category: Data category; must satisfy ``DataCategory.is_valid_category``.
            ingestor_id: Unique ID of the ingestor.

        Returns:
            True if successful; False if the category is invalid or the request fails.
        """
        if not DataCategory.is_valid_category(category):
            logger.error(f"Invalid category: {category}")
            return False

        try:
            url = f"{self.config.API_ENDPOINT}/global_meta/prepare/"
            params = {"category": category, "injestor_id": ingestor_id}

            logger.info(f"Sending prepare request for category: {category}, injester_id: {ingestor_id}")
            response = self.session.get(url, headers=self._headers(), params=params, timeout=API_TIMEOUT)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self._log_request_error("Error preparing data", e)
            return False

        logger.info(f"Successfully prepared data. Response: {self._body_for_log(response)}")
        return True

    def create_dataset(self, requires_gpu: bool = False, allow_feature_modification: bool = False, ingestor_id: str = None, category: str = None) -> Dict[str, Any]:
        """Create a new dataset via ``POST /dataset/``.

        The title is ``config.TITLE`` when set, otherwise ``"<category>_<ingestor_id>"``.

        Args:
            requires_gpu: Whether the dataset requires GPU processing.
            allow_feature_modification: Currently ignored. The value sent is
                derived from ``category`` (True only for
                ``DataCategory.TABULAR_CLASSIFICATION``). Kept for API compatibility.
            ingestor_id: Unique ID of the ingestor (used in the generated title).
            category: Data category (used in the generated title).

        Returns:
            The created dataset as returned by the API (decoded JSON).

        Raises:
            requests.exceptions.RequestException: If the API request fails.
        """
        try:
            # Config.TITLE defaults to "" rather than None when unset, so test
            # for emptiness; use this client's config, not a module global.
            title = getattr(self.config, "TITLE", None) or f"{category}_{ingestor_id}"

            allow_feature_modification = bool(category == DataCategory.TABULAR_CLASSIFICATION)

            payload = json.dumps({
                "title": title,
                "requires_gpu": requires_gpu,
                "allow_feature_modification": allow_feature_modification,
            })
            logger.info(f"Creating dataset with payload: {payload}")

            response = self.session.post(
                f"{self.config.API_ENDPOINT}/dataset/",
                data=payload,
                headers=self._headers(json_body=True),
                timeout=API_TIMEOUT,
            )
            response.raise_for_status()
            created = response.json()
        except requests.exceptions.RequestException as e:
            self._log_request_error("Error creating dataset", e)
            raise

        logger.info(f"Successfully created dataset. Response: {created}")
        return created

    def __del__(self):
        """Close the HTTP session when the client is garbage-collected."""
        if hasattr(self, 'session'):
            self.session.close()
@@ -0,0 +1,41 @@
1
+ from typing import Dict, Any, Optional
2
+ import os
3
+ from dataclasses import dataclass
4
+ import logging
5
+
6
+ @dataclass
7
+ class Config:
8
+ DB_HOST: str = os.getenv("MYSQL_HOST", "localhost")
9
+ DB_PORT: int = int(os.getenv("MYSQL_PORT", "3306"))
10
+ DB_USER: str = os.getenv("MYSQL_USER", "root")
11
+ DB_PASSWORD: str = os.getenv("MYSQL_PASSWORD", "")
12
+ DB_NAME: str = os.getenv("MYSQL_DATABASE", "ingestor_db")
13
+
14
+ BATCH_SIZE: int = int(os.getenv("BATCH_SIZE", "10"))
15
+
16
+ # Define API endpoints for different environments
17
+ API_ENDPOINTS = {
18
+ "dev": "https://dev-api.tracebloc.io",
19
+ "stg": "https://stg-api.tracebloc.io",
20
+ "prod": "https://api.tracebloc.io"
21
+ }
22
+
23
+ # Get environment and set appropriate API endpoint, default to dev
24
+ EDGE_ENV: str = os.getenv("EDGE_ENV", "dev")
25
+ API_ENDPOINT: str = API_ENDPOINTS.get(EDGE_ENV, API_ENDPOINTS["dev"])
26
+
27
+ CLIENT_USERNAME: str = os.getenv("EDGE_USERNAME", "")
28
+ CLIENT_PASSWORD: str = os.getenv("EDGE_PASSWORD", "")
29
+
30
+ STORAGE_PATH: str = os.getenv("STORAGE_PATH", "/data/shared")
31
+ SRC_PATH: str = os.getenv("SRC_PATH", "") # path to the source data
32
+ DEST_PATH: str = os.path.join(os.getenv("DEST_PATH", ""), os.getenv("TABLE_NAME", "")) # path to the destination data with table name
33
+ LABEL_FILE: str = os.getenv("LABEL_FILE", "")
34
+ COMPANY: str = os.getenv("COMPANY", "")
35
+ TABLE_NAME: str = os.getenv("TABLE_NAME", "")
36
+ TITLE: str = os.getenv("TITLE", "")
37
+
38
+ # Logging configuration
39
+ LOG_LEVEL: int = int(os.getenv("LOG_LEVEL", str(logging.INFO)))
40
+ LOG_FORMAT: Optional[str] = os.getenv("LOG_FORMAT", None)
41
+ LOG_DATE_FORMAT: Optional[str] = os.getenv("LOG_DATE_FORMAT", None)
@@ -0,0 +1,247 @@
1
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote

from sqlalchemy import (
    BigInteger,
    Boolean,
    Column,
    DateTime,
    Float,
    Integer,
    MetaData,
    String,
    Table,
    Text,
    create_engine,
    inspect,
    text,
)
from sqlalchemy.dialects.mysql import BLOB, LONGBLOB, insert
from sqlalchemy.engine import Engine

from .config import Config
from .utils.logging import setup_logging
9
+ from .utils.logging import setup_logging
10
+
11
+ # Configure unified logging with config
12
# Runs on import: load environment-based settings, apply the shared logging
# configuration, and fetch the logger used throughout this module.
config: Config = Config()
setup_logging(config)
logger = logging.getLogger(__name__)
15
+
16
class Database:
    """MySQL access layer: creates the target database and tables and upserts records."""

    def __init__(self, config: Config):
        """Connect to MySQL, creating ``config.DB_NAME`` if it does not exist.

        Args:
            config: Provides DB_HOST, DB_PORT, DB_USER, DB_PASSWORD and DB_NAME.
        """
        self.config = config
        self.engine = self._create_engine()
        self.metadata = MetaData()
        # Table objects created or reflected by create_table, keyed by table name.
        self.tables: Dict[str, Table] = {}
        self.unique_id_column: Optional[str] = None  # not read anywhere in this class

    def _create_engine(self) -> Engine:
        """Create the configured database if missing and return an engine bound to it."""
        # Percent-encode credentials with safe="" so characters such as "/",
        # "@" or ":" in the user or password cannot break URL parsing.
        user = quote(self.config.DB_USER, safe="")
        password = quote(self.config.DB_PASSWORD, safe="")
        base_connection_string = (
            f"mysql+mysqldb://{user}:{password}"
            f"@{self.config.DB_HOST}:{self.config.DB_PORT}"
        )

        # Backtick-quote the identifier (doubling embedded backticks) so names
        # containing e.g. "-" are accepted by CREATE DATABASE.
        db_identifier = "`" + self.config.DB_NAME.replace("`", "``") + "`"
        bootstrap_engine = create_engine(base_connection_string, pool_pre_ping=True)
        try:
            with bootstrap_engine.connect() as connection:
                connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_identifier}"))
                connection.commit()
        finally:
            # Only needed once; release its pooled connection.
            bootstrap_engine.dispose()

        connection_string = f"{base_connection_string}/{self.config.DB_NAME}"
        # Never print/log the connection string: it embeds the password.
        logger.info(
            f"Using MySQL database '{self.config.DB_NAME}' at "
            f"{self.config.DB_HOST}:{self.config.DB_PORT}"
        )
        return create_engine(connection_string, pool_pre_ping=True)

    def _get_sqlalchemy_type(self, mysql_type: str):
        """Map a MySQL type string such as "VARCHAR(255)" or "BIGINT" to a SQLAlchemy type.

        The base name (the text before any "(") is matched exactly first. Only
        if that fails is a substring match tried, longest names first, so that
        "BIGINT" is not captured by "INT" nor "LONGBLOB" by "BLOB".

        Returns:
            A SQLAlchemy type class, or an instance when a size argument is
            given for a type that accepts one.

        Raises:
            ValueError: If no supported type matches, or the size is not an integer.
        """
        type_mapping = {
            'VARCHAR': String,
            'TEXT': Text,
            'INT': Integer,
            'BIGINT': BigInteger,
            'FLOAT': Float,
            'BOOLEAN': Boolean,
            'DATETIME': DateTime,
            'TIMESTAMP': DateTime,
            'BLOB': BLOB,
            'LONGBLOB': LONGBLOB,
        }
        # Types whose constructor takes a leading size argument (length or
        # precision). Integer etc. take none, so e.g. "INT(11)" ignores the width.
        sized_types = {'VARCHAR', 'TEXT', 'FLOAT', 'BLOB'}

        base, _, arguments = mysql_type.strip().partition('(')
        base_name = base.strip().upper()
        tokens = base_name.split()
        if tokens and tokens[0] in type_mapping:
            key = tokens[0]
        else:
            key = next(
                (name for name in sorted(type_mapping, key=len, reverse=True) if name in base_name),
                None,
            )
        if key is None:
            raise ValueError(f"Unsupported MySQL type: {mysql_type}")

        alchemy_type = type_mapping[key]
        length = None
        if key in sized_types and arguments:
            # "FLOAT(10,2)" -> 10; an empty "()" means no size.
            first_argument = arguments.split(')')[0].split(',')[0].strip()
            if first_argument:
                length = int(first_argument)
        return alchemy_type(length) if length else alchemy_type

    def create_table(self, table_name: str, schema: Dict[str, str]):
        """Return the named table, creating it if it does not exist.

        A table that already exists in the database is reflected as-is. In
        that case *schema* is not applied.

        Args:
            table_name: Name of the table.
            schema: Mapping of custom column name to MySQL type string. These
                columns are added after the standard columns.

        Returns:
            The SQLAlchemy ``Table`` object (also cached in ``self.tables``).
        """
        if table_name in self.tables:
            return self.tables[table_name]

        inspector = inspect(self.engine)
        if table_name in inspector.get_table_names():
            self.metadata.reflect(self.engine, only=[table_name])
            table = self.metadata.tables[table_name]
            self.tables[table_name] = table
            return table

        # Columns present in every ingested table.
        standard_columns = [
            Column('id', BigInteger, primary_key=True, autoincrement=True),
            Column('created_at', DateTime, server_default=text('CURRENT_TIMESTAMP')),
            Column('updated_at', DateTime, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')),
            Column('status', Integer, server_default=text('0')),  # 1 for active, 0 for inactive
            Column('label', String(255), nullable=True),
            Column('data_intent', String(100), nullable=True),
            # Unique key used by insert_batch's upsert.
            Column('data_id', String(255), unique=True, nullable=False),
            Column('annotation', Text, nullable=True),
            Column('ingestor_id', String(255), nullable=True),
        ]
        custom_columns = [
            Column(column_name, self._get_sqlalchemy_type(mysql_type))
            for column_name, mysql_type in schema.items()
        ]

        table = Table(table_name, self.metadata, *(standard_columns + custom_columns))
        self.tables[table_name] = table
        self.metadata.create_all(self.engine, tables=[table])
        return table

    def insert_batch(self, table_name: str, records: List[Dict[str, Any]]) -> Tuple[List[int], List[Dict[str, Any]]]:
        """Upsert a batch of records keyed on the unique ``data_id`` column.

        The whole batch is first written with one ``INSERT ... ON DUPLICATE KEY
        UPDATE``. If that fails, each record is retried individually so that
        only the offending records are reported as failures.

        Args:
            table_name: Name of a table previously registered via ``create_table``.
            records: Rows to write; each should contain ``data_id``.

        Returns:
            ``(success_ids, failures)``. ``success_ids`` holds the primary-key
            ``id`` values of the written rows. ``failures`` holds
            ``{"record": ..., "error": str}`` dicts for records not written.

        Raises:
            KeyError: If ``table_name`` was not registered via ``create_table``.
        """
        if not records:
            # Same shape as the non-empty case so callers can always unpack.
            return [], []

        table = self.tables[table_name]
        success_ids: List[int] = []
        failures: List[Dict[str, Any]] = []
        # Indices of records whose outcome is already recorded, so that a later
        # connection-level error does not report them a second time.
        resolved = set()

        try:
            with self.engine.connect() as connection:
                current_time = datetime.now()
                processed_records = []
                for record in records:
                    processed_record = {**record, 'updated_at': current_time}
                    if 'created_at' not in record:
                        processed_record['created_at'] = current_time
                    processed_records.append(processed_record)

                insert_stmt = insert(table)
                # On duplicate data_id, overwrite every column except the
                # identity/creation columns. `inserted[...]` renders properly
                # quoted references to the incoming values.
                update_dict = {
                    column.name: insert_stmt.inserted[column.name]
                    for column in table.columns
                    if column.name not in ('id', 'created_at', 'data_id')
                }

                try:
                    connection.execute(
                        insert_stmt.values(processed_records).on_duplicate_key_update(**update_dict)
                    )
                    connection.commit()

                    data_ids = [record['data_id'] for record in records]
                    select_stmt = table.select().where(table.c.data_id.in_(data_ids))
                    rows = connection.execute(select_stmt).fetchall()
                    success_ids.extend(row.id for row in rows)
                    resolved.update(range(len(records)))

                except Exception as e:
                    # Fall back to one-by-one upserts to isolate bad records.
                    connection.rollback()
                    logger.warning(f"Batch insert failed, attempting individual inserts: {str(e)}")

                    for index, record in enumerate(processed_records):
                        try:
                            stmt = insert_stmt.values([record]).on_duplicate_key_update(**update_dict)
                            connection.execute(stmt)
                            connection.commit()

                            select_stmt = table.select().where(table.c.data_id == record['data_id'])
                            row = connection.execute(select_stmt).fetchone()
                            if row:
                                success_ids.append(row.id)
                            resolved.add(index)

                        except Exception as individual_error:
                            failures.append({
                                "record": record,
                                "error": str(individual_error),
                            })
                            resolved.add(index)
                            logger.error(f"Failed to process record {record.get('data_id')}: {str(individual_error)}")
                            connection.rollback()

        except Exception as e:
            logger.error(f"Database connection error in insert_batch: {str(e)}")
            failures.extend(
                {"record": record, "error": f"Database connection error: {str(e)}"}
                for index, record in enumerate(records)
                if index not in resolved
            )

        return success_ids, failures

    def get_table_schema(self, table_name: str) -> Dict[str, str]:
        """Return the table's columns as a mapping of column name to MySQL type string.

        The type strings use the names understood by ``_get_sqlalchemy_type``
        (e.g. "INT", "TEXT", "VARCHAR(255)"). Unrecognised types fall back to
        "VARCHAR".

        Args:
            table_name: Name of the table to inspect.

        Returns:
            Dictionary of column names and their MySQL types.
        """
        inspector = inspect(self.engine)
        columns = inspector.get_columns(table_name)

        type_mapping = {
            # Generic SQLAlchemy class names.
            'String': 'VARCHAR',
            'Text': 'TEXT',
            'Integer': 'INT',
            'BigInteger': 'BIGINT',
            'Float': 'FLOAT',
            'Boolean': 'BOOLEAN',
            'DateTime': 'DATETIME',
            'BLOB': 'BLOB',
            'LONGBLOB': 'LONGBLOB',
            # Upper-case dialect class names, as returned by reflecting a MySQL table.
            'VARCHAR': 'VARCHAR',
            'TEXT': 'TEXT',
            'INTEGER': 'INT',
            'BIGINT': 'BIGINT',
            'FLOAT': 'FLOAT',
            'BOOLEAN': 'BOOLEAN',
            'DATETIME': 'DATETIME',
            'TIMESTAMP': 'TIMESTAMP',
        }

        schema = {}
        for column in columns:
            column_type = column['type']
            type_name = column_type.__class__.__name__

            if type_name == 'TINYINT':
                # MySQL stores BOOLEAN as TINYINT(1).
                is_bool = getattr(column_type, 'display_width', None) == 1
                mysql_type = 'BOOLEAN' if is_bool else 'INT'
            else:
                mysql_type = type_mapping.get(type_name, 'VARCHAR')

            length = getattr(column_type, 'length', None)
            if mysql_type == 'VARCHAR' and length is not None:
                mysql_type = f"{mysql_type}({length})"

            schema[column['name']] = mysql_type

        return schema