sdk-seshat-python 0.3.14__tar.gz → 0.3.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/PKG-INFO +2 -1
  2. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/pyproject.toml +3 -2
  3. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/exceptions.py +4 -0
  4. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/mixins.py +22 -4
  5. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/saver/database.py +64 -12
  6. sdk_seshat_python-0.3.16/seshat/source/saver/utils/postgres.py +105 -0
  7. sdk_seshat_python-0.3.14/seshat/source/saver/utils/postgres.py +0 -22
  8. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/LICENSE +0 -0
  9. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/README.md +0 -0
  10. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/__init__.py +0 -0
  11. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/__main__.py +0 -0
  12. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/data_class/__init__.py +0 -0
  13. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/data_class/base.py +0 -0
  14. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/data_class/pandas.py +0 -0
  15. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/data_class/pyspark.py +0 -0
  16. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/__init__.py +0 -0
  17. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/base.py +0 -0
  18. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/__init__.py +0 -0
  19. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/base.py +0 -0
  20. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/general/__init__.py +0 -0
  21. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/general/classification.py +0 -0
  22. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/general/clustering.py +0 -0
  23. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/general/regression.py +0 -0
  24. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/recommendation/__init__.py +0 -0
  25. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/recommendation/diversity.py +0 -0
  26. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/evaluation/evaluator/recommendation/ranking.py +0 -0
  27. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/feature_view/__init__.py +0 -0
  28. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/feature_view/base.py +0 -0
  29. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/__init__.py +0 -0
  30. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/__init__.py +0 -0
  31. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/base.py +0 -0
  32. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/code_inspect.py +0 -0
  33. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/job_status.py +0 -0
  34. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/setup_project.py +0 -0
  35. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/command/submit_to_network.py +0 -0
  36. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/config.py +0 -0
  37. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/exceptions.py +0 -0
  38. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/lazy_config.py +0 -0
  39. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/models.py +0 -0
  40. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/README.md-tmpl +0 -0
  41. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/config.py-tmpl +0 -0
  42. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/env-templ +0 -0
  43. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/jobignore-tmpl +0 -0
  44. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/pyproject._toml-tmpl +0 -0
  45. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/recommender-jupyter.ipynb-tmpl +0 -0
  46. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/general/template/recommender.py-tmpl +0 -0
  47. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/profiler/__init__.py +0 -0
  48. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/profiler/base.py +0 -0
  49. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/profiler/decorator.py +0 -0
  50. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/profiler/format.py +0 -0
  51. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/__init__.py +0 -0
  52. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/base.py +0 -0
  53. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/database/__init__.py +0 -0
  54. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/database/base.py +0 -0
  55. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/flip_side/__init__.py +0 -0
  56. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/flip_side/base.py +0 -0
  57. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/local/__init__.py +0 -0
  58. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/local/base.py +0 -0
  59. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/multisource/__init__.py +0 -0
  60. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/multisource/base.py +0 -0
  61. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/saver/__init__.py +0 -0
  62. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/saver/base.py +0 -0
  63. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/source/saver/utils/__init__.py +0 -0
  64. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/__init__.py +0 -0
  65. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/aggregator/__init__.py +0 -0
  66. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/aggregator/base.py +0 -0
  67. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/augmenter/__init__.py +0 -0
  68. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/augmenter/base.py +0 -0
  69. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/base.py +0 -0
  70. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/deriver/__init__.py +0 -0
  71. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/deriver/base.py +0 -0
  72. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/deriver/from_database.py +0 -0
  73. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/imputer/__init__.py +0 -0
  74. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/imputer/base.py +0 -0
  75. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/merger/__init__.py +0 -0
  76. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/merger/base.py +0 -0
  77. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/merger/nested_key.py +0 -0
  78. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pipeline/__init__.py +0 -0
  79. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pipeline/base.py +0 -0
  80. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pipeline/branch.py +0 -0
  81. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pipeline/recommendation/__init__.py +0 -0
  82. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pipeline/recommendation/address_pipeline.py +0 -0
  83. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pseudo/__init__.py +0 -0
  84. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/pseudo/table_existence.py +0 -0
  85. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/reducer/__init__.py +0 -0
  86. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/reducer/base.py +0 -0
  87. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/scaler/__init__.py +0 -0
  88. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/scaler/base.py +0 -0
  89. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/schema/__init__.py +0 -0
  90. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/schema/base.py +0 -0
  91. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/__init__.py +0 -0
  92. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/base.py +0 -0
  93. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/block/__init__.py +0 -0
  94. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/block/base.py +0 -0
  95. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/random/__init__.py +0 -0
  96. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/random/base.py +0 -0
  97. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/time_line/__init__.py +0 -0
  98. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/splitter/time_line/base.py +0 -0
  99. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/trimmer/__init__.py +0 -0
  100. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/trimmer/base.py +0 -0
  101. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/vectorizer/__init__.py +0 -0
  102. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/vectorizer/base.py +0 -0
  103. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/vectorizer/cosine_similarity.py +0 -0
  104. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/vectorizer/pivot.py +0 -0
  105. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/transformer/vectorizer/utils.py +0 -0
  106. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/__init__.py +0 -0
  107. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/batcher.py +0 -0
  108. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/binary_utils.py +0 -0
  109. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/clean_json.py +0 -0
  110. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/col_to_list.py +0 -0
  111. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/contracts.py +0 -0
  112. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/file.py +0 -0
  113. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/file_cryptography.py +0 -0
  114. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/filter_json.py +0 -0
  115. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/grouper.py +0 -0
  116. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/jobignore.py +0 -0
  117. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/join_columns_to_list.py +0 -0
  118. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/join_str.py +0 -0
  119. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/llm_client/__init__.py +0 -0
  120. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/llm_client/chatbot_factory.py +0 -0
  121. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/logging/__init__.py +0 -0
  122. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/logging/base_logger.py +0 -0
  123. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/logging/console_logger.py +0 -0
  124. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/logging/logstash_logger.py +0 -0
  125. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/logging/multi_logger.py +0 -0
  126. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/memory.py +0 -0
  127. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/mixin.py +0 -0
  128. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/obfuscate.py +0 -0
  129. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/package_utils.py +0 -0
  130. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/pandas_func.py +0 -0
  131. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/patching.py +0 -0
  132. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/pyspark_func.py +0 -0
  133. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/rest.py +0 -0
  134. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/singleton.py +0 -0
  135. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/validation.py +0 -0
  136. {sdk_seshat_python-0.3.14 → sdk_seshat_python-0.3.16}/seshat/utils/zip_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sdk-seshat-python
3
- Version: 0.3.14
3
+ Version: 0.3.16
4
4
  Summary: Seshat python SDK is a library to help create ML data pipelines.
5
5
  License: Commercial - see LICENSE.txt
6
6
  Author: SeshatLabs
@@ -25,6 +25,7 @@ Requires-Dist: loguru (>=0.7.3,<0.8.0)
25
25
  Requires-Dist: memory-profiler (>=0.61.0,<0.62.0)
26
26
  Requires-Dist: openai (>=1.73.0,<2.0.0)
27
27
  Requires-Dist: pandas (>=2.2.1,<3.0.0)
28
+ Requires-Dist: psycopg2 (>=2.9,<3.0) ; extra == "postgres-support"
28
29
  Requires-Dist: pyarmor (>=8.5.1,<9.0.0)
29
30
  Requires-Dist: pydantic (>=2.7.4,<3.0.0)
30
31
  Requires-Dist: pyspark (>=3.5.1,<4.0.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sdk-seshat-python"
3
- version = "0.3.14"
3
+ version = "0.3.16"
4
4
  description = "Seshat python SDK is a library to help create ML data pipelines."
5
5
  authors = ["SeshatLabs <info@seshatlabs.xyz>"]
6
6
  packages = [{ include = "seshat", from = "." }]
@@ -33,11 +33,12 @@ langchain-community = "^0.3.21"
33
33
  langchain-openai = "^0.3.12"
34
34
  pyarmor = "^8.5.1"
35
35
  croniter = "^6.0.0"
36
+ psycopg2 = { version = "^2.9", optional = true }
36
37
  python-logstash-async = "^4.0.2"
37
38
 
38
39
  [tool.poetry.extras]
39
40
  flipside_support = ["flipside"]
40
- postgres_support = ["psycopg2-binary"]
41
+ postgres_support = ["psycopg2"]
41
42
 
42
43
  [tool.poetry.group.dev.dependencies]
43
44
  flake8 = "^7.0.0"
@@ -14,3 +14,7 @@ class FlipSideApiError(Exception):
14
14
  def __init__(self, e):
15
15
  message = f"An api error occurred: {str(e)}"
16
16
  super().__init__(message)
17
+
18
+
19
class PostgresConstraintError(Exception):
    """Raised when a unique constraint cannot be created on a Postgres table."""
@@ -1,7 +1,15 @@
1
- from typing import Optional, Callable
1
+ from typing import Callable, Optional
2
2
 
3
3
  import sqlalchemy as db
4
- from sqlalchemy import Engine, create_engine, inspect, Column, MetaData, Table
4
+ from sqlalchemy import (
5
+ Column,
6
+ Engine,
7
+ MetaData,
8
+ PrimaryKeyConstraint,
9
+ Table,
10
+ create_engine,
11
+ inspect,
12
+ )
5
13
 
6
14
  from seshat.transformer.schema import Schema
7
15
 
@@ -78,11 +86,21 @@ class SQLMixin:
78
86
 
79
87
  def create_table(self, schema: Schema, table: str):
80
88
  table_columns = []
89
+ pk_cols = []
81
90
  for col in schema.cols:
82
91
  col_name = col.to or col.original
83
92
  col_type = getattr(db, col.dtype or "String")
84
- table_columns.append(Column(col_name, col_type))
85
- _, metadata = self.get_table(table, False, *table_columns, extend_existing=True)
93
+ table_columns.append(Column(col_name, col_type, primary_key=col.is_id))
94
+ if col.is_id:
95
+ pk_cols.append(col_name)
96
+ constraints = []
97
+ if pk_cols:
98
+ constraints.append(
99
+ PrimaryKeyConstraint(*pk_cols, name=f"{table}_pk_{'_'.join(pk_cols)}")
100
+ )
101
+ _, metadata = self.get_table(
102
+ table, False, *table_columns, *constraints, extend_existing=True
103
+ )
86
104
  metadata.create_all(self.get_engine())
87
105
 
88
106
  def get_table(self, table_name, autoload, *args, **kwargs):
@@ -2,27 +2,43 @@ import hashlib
2
2
  import statistics
3
3
  from typing import List
4
4
 
5
- from sqlalchemy import Index
6
- from sqlalchemy import inspect, select, update, and_, true
5
+ import sqlalchemy as db
6
+ from sqlalchemy import (
7
+ Column,
8
+ Index,
9
+ MetaData,
10
+ Table,
11
+ and_,
12
+ inspect,
13
+ select,
14
+ true,
15
+ update,
16
+ )
7
17
 
8
18
  from seshat.data_class import SFrame
9
19
  from seshat.general.exceptions import InvalidArgumentsError
10
20
  from seshat.source.mixins import SQLMixin
11
21
  from seshat.source.saver import Saver
12
- from seshat.source.saver.base import SaveConfig
22
+ from seshat.source.saver.base import POSTGRES, SaveConfig
13
23
  from seshat.source.saver.utils import PostgresUtils
24
+ from seshat.transformer.schema import Schema
14
25
  from seshat.transformer.schema.base import UpdateFuncs
26
+ from seshat.transformer.trimmer.base import NaNTrimmer
15
27
 
16
28
 
17
29
  class SQLDBSaver(SQLMixin, Saver):
18
30
  def save(self, sf: SFrame, *args, **kwargs):
31
+
19
32
  for config in self.save_configs:
20
33
  self.fill_cols_to_field(config.schema)
21
34
  self.ensure_table_exists(config.table, config.schema)
22
35
  self.create_index(config)
36
+
37
+ selected_sf = sf.get(config.sf_key)
38
+ selected_sf = self.drop_nan_ids(selected_sf, config.schema)
39
+
23
40
  if config.clear_table:
24
41
  self.delete(config.table)
25
- selected_sf = sf.get(config.sf_key)
26
42
  if config.strategy == "update":
27
43
  self.update(selected_sf, config)
28
44
  elif config.strategy == "copy":
@@ -30,15 +46,25 @@ class SQLDBSaver(SQLMixin, Saver):
30
46
  else:
31
47
  self.insert(selected_sf, config)
32
48
 
49
def ensure_table_exists(self, table: str, schema: Schema):
    """Create ``table`` from ``schema`` unless the engine already lists it.

    The table names reported by the engine's inspector are consulted first;
    ``create_table`` is only called when ``table`` is not among them.
    """
    known_tables = inspect(self.get_engine()).get_table_names()
    if table not in known_tables:
        self.create_table(schema, table)
54
+
55
def create_table(self, schema: Schema, table: str):
    """Create ``table`` with one column per entry in ``schema.cols``.

    Each column is named after ``col.to`` and typed with the SQLAlchemy type
    whose name is ``col.dtype`` (``String`` when no dtype is set).
    """
    # NOTE(review): this override takes precedence over
    # SQLMixin.create_table, which in this same release falls back to
    # col.original and declares a primary key for id columns; confirm that
    # dropping both behaviours here is intentional.
    columns = [
        Column(col.to, getattr(db, col.dtype or "String"))
        for col in schema.cols
    ]
    _, metadata = self.get_table(table, False, *columns, extend_existing=True)
    metadata.create_all(self.get_engine())
63
+
33
64
  def delete(self, table_name):
34
65
  table, _ = self.get_table(table_name, autoload=True)
35
66
  self.write_on_db(table.delete())
36
67
 
37
- def drop_table(self, table_name):
38
- if table_name in inspect(self.get_engine()).get_table_names():
39
- table, _ = self.get_table(table_name, autoload=True)
40
- table.drop(self.get_engine())
41
-
42
68
  def insert(self, selected_sf: SFrame, config: SaveConfig):
43
69
  values = self.prepare_sf_to_insert(selected_sf, config).to_dict()
44
70
  table, _ = self.get_table(config.table, autoload=True)
@@ -51,7 +77,18 @@ class SQLDBSaver(SQLMixin, Saver):
51
77
  def update(self, selected_sf: SFrame, config: SaveConfig):
52
78
  table, _ = self.get_table(config.table, autoload=True)
53
79
  values = selected_sf.to_dict()
54
- id_cols = config.schema.get_id(return_first=False)
80
+ id_cols_schema = config.schema.get_id(return_first=False)
81
+
82
+ if self.db_type == POSTGRES:
83
+ try:
84
+ PostgresUtils.ensure_unique_constraint(
85
+ self.get_engine(), table, config.schema
86
+ )
87
+ stmt = PostgresUtils.generate_upsert_stmt(table, values, config)
88
+ self.write_on_db(stmt)
89
+ return
90
+ except Exception:
91
+ pass
55
92
 
56
93
  rows_to_update = self._get_existing_rows(table, values, config.schema)
57
94
  rows_to_create = []
@@ -59,7 +96,7 @@ class SQLDBSaver(SQLMixin, Saver):
59
96
 
60
97
  if rows_to_create:
61
98
  self.write_on_db(table.insert(), rows_to_create)
62
- db_id_cols = tuple(getattr(table.c, id_col.to) for id_col in id_cols)
99
+ db_id_cols = tuple(getattr(table.c, id_col.to) for id_col in id_cols_schema)
63
100
  for row in rows_to_update:
64
101
  condition = None
65
102
  for db_id in db_id_cols:
@@ -68,7 +105,6 @@ class SQLDBSaver(SQLMixin, Saver):
68
105
  condition = new_condition
69
106
  else:
70
107
  condition = and_(condition, new_condition)
71
-
72
108
  update_query = update(table).where(condition).values(row)
73
109
  self.write_on_db(update_query)
74
110
 
@@ -137,6 +173,12 @@ class SQLDBSaver(SQLMixin, Saver):
137
173
  )
138
174
  return self.get_from_db(query)
139
175
 
176
def get_table(self, table_name, autoload, *args, **kwargs):
    """Build a ``Table`` bound to a fresh ``MetaData`` and return both.

    When ``autoload`` is truthy and the caller did not pass
    ``autoload_with``, the engine from ``get_engine()`` is supplied for it.
    Extra positional and keyword arguments are forwarded to ``Table``.

    Returns:
        A ``(table, metadata)`` tuple.
    """
    metadata = MetaData()
    if autoload:
        engine = self.get_engine()
        if "autoload_with" not in kwargs:
            kwargs["autoload_with"] = engine
    table = Table(table_name, metadata, *args, **kwargs)
    return table, metadata
181
+
140
182
  def get_from_db(self, query):
141
183
  with self.get_engine().connect() as conn:
142
184
  result = conn.execute(query)
@@ -177,5 +219,15 @@ class SQLDBSaver(SQLMixin, Saver):
177
219
  joined_cols = "-".join(set(columns))
178
220
  return hashlib.sha256(joined_cols.encode("utf-8")).hexdigest()
179
221
 
222
def drop_nan_ids(self, sf: SFrame, schema: Schema):
    """Apply a ``NaNTrimmer`` to ``sf`` for each distinct id source column.

    Id columns are taken from ``schema.get_id(return_first=False)``; every
    distinct ``original`` name is trimmed once, in first-seen order, and the
    resulting frame is returned.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    id_sources = dict.fromkeys(
        id_col.original for id_col in schema.get_id(return_first=False)
    )
    for source_name in id_sources:
        sf = NaNTrimmer(subset=[source_name])(sf)
    return sf
231
+
180
232
  def calculate_complexity(self):
181
233
  return 60
@@ -0,0 +1,105 @@
1
+ import io
2
+
3
+ import pandas as pd
4
+ from sqlalchemy import func, inspect, text
5
+ from sqlalchemy.dialects import postgresql as pg
6
+ from sqlalchemy.engine import Engine
7
+
8
+ from seshat.source.exceptions import PostgresConstraintError
9
+
10
+ COPY_QUERY = (
11
+ "COPY {table_name} FROM STDIN WITH (FORMAT csv, HEADER TRUE, DELIMITER '\t');"
12
+ )
13
+
14
+
15
class PostgresUtils:
    """Static helpers for PostgreSQL bulk loading, constraints and upserts."""

    # PostgreSQL stores identifiers truncated to this many bytes.
    _MAX_IDENTIFIER_BYTES = 63

    @staticmethod
    def copy(engine: "Engine", table_name: str, df: "pd.DataFrame"):
        """Bulk-load ``df`` into ``table_name`` using ``COPY ... FROM STDIN``.

        The frame is serialised as tab-separated CSV with a header row and
        streamed through a raw DBAPI connection. The cursor and connection
        are always released, even when the copy fails.
        """
        csv_content = io.StringIO()
        df.to_csv(csv_content, sep="\t", header=True, index=False)
        csv_content.seek(0)
        conn = engine.raw_connection()
        try:
            cur = conn.cursor()
            try:
                cur.copy_expert(
                    COPY_QUERY.format(table_name=table_name), csv_content
                )
                conn.commit()
            finally:
                cur.close()
        finally:
            conn.close()

    @staticmethod
    def get_constraint_name(table_name, id_columns):
        """Return the unique-constraint name for a table and its id columns.

        Column names are sorted so the name does not depend on their order.
        """
        cols_part = "_".join(sorted(id_columns))
        return f"uq_{table_name}_{cols_part}"

    @staticmethod
    def _truncate_identifier(name):
        """Return ``name`` cut to the byte length PostgreSQL would store."""
        encoded = name.encode("utf-8")
        limit = PostgresUtils._MAX_IDENTIFIER_BYTES
        if len(encoded) <= limit:
            return name
        # "ignore" drops a multi-byte character split by the cut.
        return encoded[:limit].decode("utf-8", "ignore")

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` double-quoted, with embedded quotes doubled."""
        return '"{}"'.format(str(name).replace('"', '""'))

    @staticmethod
    def constraint_exists(engine, table, constraint_name):
        """Return True if ``table`` has a unique constraint with this name.

        The name is also compared in its truncated form, because a name
        longer than the identifier limit is stored truncated and would
        otherwise never match.
        """
        candidates = {
            constraint_name,
            PostgresUtils._truncate_identifier(constraint_name),
        }
        insp = inspect(engine)
        return any(
            cons["name"] in candidates
            for cons in insp.get_unique_constraints(table.name)
        )

    @staticmethod
    def ensure_unique_constraint(engine, table, schema):
        """Ensure a unique constraint exists over the schema's id columns.

        Returns:
            True when the constraint already existed or was created.

        Raises:
            PostgresConstraintError: when the schema has no id columns, when
                existing rows violate uniqueness, or when the role lacks the
                privilege to alter the table.
            Exception: any other database error is re-raised unchanged.
        """
        id_cols = [col.to for col in schema.get_id(return_first=False)]
        if not id_cols:
            raise PostgresConstraintError(
                f"Cannot create unique constraint on table '{table.name}': "
                "the schema declares no id columns."
            )
        constraint_name = PostgresUtils.get_constraint_name(table.name, id_cols)
        if PostgresUtils.constraint_exists(engine, table, constraint_name):
            return True

        quote = PostgresUtils._quote_identifier
        sql = "ALTER TABLE {table_name} ADD CONSTRAINT {constraint_name} UNIQUE ({cols});".format(
            table_name=quote(table.name),
            constraint_name=quote(constraint_name),
            cols=", ".join(quote(col) for col in id_cols),
        )
        try:
            # engine.begin() commits on success; a plain engine.connect()
            # block would roll the DDL back when the connection closes.
            with engine.begin() as conn:
                conn.execute(text(sql))
        except Exception as e:
            # Classify on the driver error, not str(e): the SQLAlchemy
            # message embeds the statement, which itself contains "UNIQUE".
            driver_error = getattr(e, "orig", None)
            pgcode = getattr(driver_error, "pgcode", None)
            message = str(driver_error if driver_error is not None else e).lower()
            if pgcode is not None:
                is_duplicate = pgcode == "23505"  # unique_violation
                is_permission = pgcode == "42501"  # insufficient_privilege
            else:
                is_duplicate = "duplicate" in message or "unique" in message
                is_permission = "permission" in message or "not allowed" in message
            if is_duplicate:
                raise PostgresConstraintError(
                    f"Cannot create unique constraint '{constraint_name}' on table '{table.name}': "
                    f"duplicate data exists for columns {id_cols}."
                ) from e
            if is_permission:
                raise PostgresConstraintError(
                    f"Insufficient privileges to create unique constraint '{constraint_name}' on table '{table.name}'."
                ) from e
            raise
        return True

    @staticmethod
    def generate_upsert_stmt(table, values, config):
        """Build an ``INSERT ... ON CONFLICT`` statement for ``values``.

        Rows are re-keyed from each column's ``original`` name to its ``to``
        name. On conflict over the id columns, non-id columns are combined
        according to their ``update_func``: ``"sum"`` adds the stored and
        incoming values, ``"mean"`` halves that sum, anything else replaces
        the stored value. NULLs count as 0 for ``sum`` and ``mean``.

        When the schema has no non-id columns there is nothing to update, so
        conflicting rows are skipped with ``DO NOTHING``.
        """
        cols = config.schema.cols
        upsert_rows = [
            {col.to: row[col.original] for col in cols} for row in values
        ]
        insert_stmt = pg.insert(table)

        update_dict = {}
        for col in cols:
            if col.is_id:
                continue
            func_name = col.update_func or "replace"
            incoming = insert_stmt.excluded[col.to]
            if func_name in ("sum", "mean"):
                total = func.coalesce(table.c[col.to], 0) + func.coalesce(
                    incoming, 0
                )
                update_dict[col.to] = total if func_name == "sum" else total / 2
            else:
                update_dict[col.to] = incoming

        conflict_cols = [
            col.to for col in config.schema.get_id(return_first=False)
        ]
        stmt = insert_stmt.values(upsert_rows)
        if not update_dict:
            # on_conflict_do_update rejects an empty set_ mapping.
            return stmt.on_conflict_do_nothing(index_elements=conflict_cols)
        return stmt.on_conflict_do_update(
            index_elements=conflict_cols, set_=update_dict
        )
@@ -1,22 +0,0 @@
1
- import io
2
-
3
- import pandas as pd
4
- from sqlalchemy.engine import Engine
5
-
6
- COPY_QUERY = (
7
- "COPY {table_name} FROM STDIN WITH (FORMAT csv, HEADER TRUE, DELIMITER '\t');"
8
- )
9
-
10
-
11
- class PostgresUtils:
12
- @staticmethod
13
- def copy(engine: Engine, table_name: str, df: pd.DataFrame):
14
- csv_content = io.StringIO()
15
- df.to_csv(csv_content, sep="\t", header=True, index=False)
16
- csv_content.seek(0)
17
- conn = engine.raw_connection()
18
- cur = conn.cursor()
19
- cur.copy_expert(COPY_QUERY.format(table_name=table_name), csv_content)
20
- conn.commit()
21
- cur.close()
22
- conn.close()