nfscache 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nfscache/__init__.py +3 -0
- nfscache/data/__init__.py +0 -0
- nfscache/data/data_container.py +29 -0
- nfscache/data/data_holder.py +23 -0
- nfscache/database/__init__.py +0 -0
- nfscache/database/oracle_env.py +77 -0
- nfscache/database/oracle_pool.py +69 -0
- nfscache/database/oracle_read.py +100 -0
- nfscache/database/oracle_write.py +182 -0
- nfscache/database/oracle_write_container.py +186 -0
- nfscache/nfs_cache.py +836 -0
- nfscache/util/__init__.py +0 -0
- nfscache/util/generate_parquets.py +164 -0
- nfscache/util/main.py +80 -0
- nfscache/util/swarm_file.py +204 -0
- nfscache/util/swarm_sql.py +328 -0
- nfscache-0.1.0.dist-info/METADATA +284 -0
- nfscache-0.1.0.dist-info/RECORD +21 -0
- nfscache-0.1.0.dist-info/WHEEL +4 -0
- nfscache-0.1.0.dist-info/entry_points.txt +2 -0
- nfscache-0.1.0.dist-info/licenses/LICENSE +21 -0
nfscache/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from nfscache.data.data_holder import DataHolder
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DataContainer:
    """Thin wrapper that stores a headers/data mapping in a ``DataHolder``.

    ``input_data`` must contain:

    * ``"headers"`` -- any iterable of column names; stored as a tuple.
    * ``"data"`` -- stored unchanged as ``rows_data_pl`` (callers in this
      package pass a ``pl.DataFrame``).

    Raises:
        KeyError: if either key is missing.
    """

    __slots__ = ("data",)

    def __init__(
        self,
        input_data: dict[str, Any],
    ) -> None:
        holder = DataHolder()
        holder.headers = tuple(input_data["headers"])
        holder.rows_data_pl = input_data["data"]
        self.data = holder
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
if __name__ == "__main__":
    # Demo: wrap a small Polars table in a DataContainer and print it back.
    headers = ("COL1", "COL2", "COL3")
    rows = [(1.1, 1, "A"), (2.2, 2, "B"), (3.3, 3, "C"), (4.4, 4, "D")]
    sample_frame = pl.DataFrame(rows, schema=headers, orient="row")
    INPUT_DATA: dict[str, Any] = {"headers": headers, "data": sample_frame}
    data = DataContainer(INPUT_DATA)
    print("headers:", data.data.headers)
    print("table:", data.data.rows_data_pl)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
class DataHolder:
    """Mutable record pairing column headers with a Polars table.

    A fresh instance has empty ``headers`` and an empty ``pl.DataFrame``.
    Callers (e.g. ``DataContainer``) assign both attributes afterwards.
    """

    # Column names, in table order.
    headers: tuple
    # Row data. The class-level ``None`` is always shadowed by ``__init__``.
    rows_data_pl: pl.DataFrame | None = None

    def __init__(self):
        self.headers = ()
        self.rows_data_pl = pl.DataFrame()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
if __name__ == "__main__":
    # Demo: populate a DataHolder by hand from tuple-based sample rows.
    INPUT_DATA: dict[str, Any] = {
        "headers": ("COL1", "COL2", "COL3"),
        "data": [(1.1, 1, "A"), (2.2, 2, "B"), (3.3, 3, "C"), (4.4, 4, "D")],
    }
    data = DataHolder()
    column_names = tuple(INPUT_DATA["headers"])
    data.headers = column_names
    data.rows_data_pl = pl.DataFrame(
        INPUT_DATA["data"],
        schema=INPUT_DATA["headers"],
        orient="row",
    )
    print("headers:", data.headers)
    print("table:", data.rows_data_pl)
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from argparse import Namespace
|
|
2
|
+
from collections.abc import Callable, Mapping
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
# (attribute name set on the argparse.Namespace, converter applied to the raw
# .env string before it is stored)
EnvField = tuple[str, Callable[[str], object]]

# .env key -> where and how apply_dotenv() stores its value on a Namespace.
ORACLE_ENV_FIELDS: Mapping[str, EnvField] = {
    "ORACLE_HOST": ("host", str),
    "ORACLE_PORT": ("port", int),
    "ORACLE_SERVICE": ("service", str),
    "ORACLE_USER": ("user", str),
    "ORACLE_PASSWORD": ("password", str),
    "ORACLE_TABLE": ("table", str),
    "ORACLE_BATCH_SIZE": ("batch_size", int),
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_dotenv() -> Path | None:
|
|
19
|
+
for path in (Path.cwd(), *Path.cwd().parents):
|
|
20
|
+
dotenv_path = path / ".env"
|
|
21
|
+
try:
|
|
22
|
+
if dotenv_path.is_file():
|
|
23
|
+
return dotenv_path
|
|
24
|
+
except OSError:
|
|
25
|
+
# Unreadable .env (e.g. restricted perms): treat as absent.
|
|
26
|
+
continue
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def read_dotenv() -> dict[str, str]:
    """Parse the nearest ``.env`` file (located by ``find_dotenv``) into a dict.

    Parsing rules, applied per line:

    * blank lines and lines whose first non-space character is ``#`` are skipped;
    * a leading ``export `` prefix is removed;
    * lines without ``=`` are skipped;
    * the line is split on the first ``=`` and both sides are stripped;
    * a value of length >= 2 wrapped in matching ``'`` or ``"`` is unwrapped.

    No variable expansion or inline-comment stripping is performed. A later
    duplicate key overwrites an earlier one.

    Returns:
        The parsed key/value pairs. Returns an empty dict when no ``.env``
        exists or it cannot be read (``OSError``).
    """
    dotenv_path = find_dotenv()
    if dotenv_path is None:
        return {}

    try:
        # "utf-8-sig" drops a leading UTF-8 byte-order mark if present and is
        # otherwise identical to "utf-8". With plain "utf-8", the BOM
        # ("\ufeff", which str.strip() does not remove) stayed glued to the
        # first key, so that variable was silently never matched.
        text = dotenv_path.read_text(encoding="utf-8-sig")
    except OSError:
        # Fall back to defaults rather than crashing on an unreadable .env.
        return {}

    values: dict[str, str] = {}
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        if stripped.startswith("export "):
            stripped = stripped.removeprefix("export ").strip()
        if "=" not in stripped:
            continue

        key, value = stripped.split("=", 1)
        key = key.strip()
        value = value.strip()
        if (
            len(value) >= 2
            and value[0] == value[-1]
            and value[0] in {"'", '"'}
        ):
            value = value[1:-1]
        values[key] = value

    return values
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def apply_dotenv(
    args: Namespace,
    *,
    fields: Mapping[str, EnvField] = ORACLE_ENV_FIELDS,
) -> None:
    """Overlay values from the nearest ``.env`` onto ``args`` in place.

    Each ``fields`` entry ``ENV_KEY -> (attr_name, coerce)`` is applied only
    when ``ENV_KEY`` is present in the ``.env`` with a non-empty value; then
    ``args.<attr_name>`` is set to ``coerce(value)``. Missing or empty keys
    leave ``args`` untouched. ``.env`` keys not listed in ``fields`` are
    ignored.

    Args:
        args: Namespace to update (mutated in place).
        fields: Mapping of ``.env`` keys to ``(attribute, converter)`` pairs.

    Raises:
        ValueError: if a converter rejects a value (e.g. a non-numeric
            ``ORACLE_PORT``). The message names the ``.env`` key and the
            original error is chained.
    """
    values = read_dotenv()
    for env_key, (arg_name, coerce) in fields.items():
        value = values.get(env_key, "")
        if value == "":
            # Absent or explicitly blank: keep whatever args already holds.
            continue
        try:
            coerced = coerce(value)
        except ValueError as exc:
            # A bare int() error ("invalid literal for int() ...") gives no
            # hint that the bad value came from the .env file.
            raise ValueError(f"Invalid value for {env_key} in .env: {exc}") from exc
        setattr(args, arg_name, coerced)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Process-local oracledb connection pool, exposed as an NFSCache `connect_factory`.
|
|
2
|
+
|
|
3
|
+
`NFSCache.connect_factory` is an opaque `Callable[[], connection]` that the cache
|
|
4
|
+
uses as `with connect_factory() as conn:`. A pooled connection's `__exit__`
|
|
5
|
+
*releases* it back to the pool instead of closing the socket, so wiring a pool in
|
|
6
|
+
here removes the per-call `oracledb.connect` cost on every version probe and warm
|
|
7
|
+
hit while keeping `nfs_cache.py` free of any oracledb dependency.
|
|
8
|
+
|
|
9
|
+
Pools are cached per process (keyed by pid + DSN + user): connections are not
|
|
10
|
+
shareable across processes (ProcessPoolExecutor workers each build their own), and
|
|
11
|
+
a forked child must not reuse the parent's pool. The pid in the key guards that.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import os
|
|
16
|
+
import threading
|
|
17
|
+
|
|
18
|
+
import oracledb
|
|
19
|
+
|
|
20
|
+
# Process-local pool registry, keyed by (pid, dsn, user); filled by get_pool().
_pools: dict[tuple[int, str, str], "oracledb.ConnectionPool"] = {}
# Guards pool creation so concurrent first calls build only one pool per key.
_lock = threading.Lock()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _dsn(args: argparse.Namespace) -> str:
|
|
25
|
+
return f"{args.host}:{args.port}/{args.service}"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_pool(
    args: argparse.Namespace,
    *,
    min_size: int = 1,
    max_size: int = 4,
) -> "oracledb.ConnectionPool":
    """Return a process-local pool for `args`, creating it once per process.

    Pools are cached in ``_pools`` under ``(os.getpid(), dsn, args.user)``.
    ``min_size``/``max_size`` (and the password) are not part of the key, so a
    later call with different sizes gets the pool created by the first call.
    """
    cache_key = (os.getpid(), _dsn(args), args.user)

    # Fast path: once this process has a pool, no locking is needed.
    existing = _pools.get(cache_key)
    if existing is not None:
        return existing

    with _lock:
        # Re-check under the lock: another thread may have created it meanwhile.
        existing = _pools.get(cache_key)
        if existing is None:
            existing = oracledb.create_pool(
                user=args.user,
                password=args.password,
                dsn=_dsn(args),
                min=min_size,
                max=max_size,
            )
            _pools[cache_key] = existing
    return existing
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def make_pool_factory(
    args: argparse.Namespace,
    *,
    min_size: int = 1,
    max_size: int = 4,
):
    """Build a `connect_factory` that acquires from a process-local pool.

    Calling the returned zero-argument function resolves the pool through
    ``get_pool`` and returns ``pool.acquire()``. Building the factory itself
    never touches the database (safe to do at import time); the pool is
    created lazily on the first call.
    """

    def acquire_connection() -> "oracledb.Connection":
        pool = get_pool(args, min_size=min_size, max_size=max_size)
        return pool.acquire()

    return acquire_connection
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import oracledb
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from nfscache.database.oracle_env import apply_dotenv
|
|
6
|
+
from nfscache.database.oracle_pool import make_pool_factory
|
|
7
|
+
from nfscache.util.main import nfscache
|
|
8
|
+
from nfscache.data.data_container import DataContainer
|
|
9
|
+
|
|
10
|
+
# Rows per cursor.fetchmany() call (and cursor.arraysize) on a cache-miss read.
DEFAULT_BATCH_SIZE = 10000
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def oracle_args() -> argparse.Namespace:
    """Return connection settings: built-in defaults overlaid with ``.env``.

    The settings are not read from the command line. This function is called
    at import time (to build the connect factory) and on every cache miss.
    ``apply_dotenv`` replaces any field that has a non-empty value in the
    nearest ``.env``.
    """
    defaults = {
        "host": "localhost",
        "port": 1521,
        "service": "FREEPDB1",
        "user": "SOMEUSER",
        "password": "cache",
        "batch_size": DEFAULT_BATCH_SIZE,
    }
    settings = argparse.Namespace(**defaults)
    apply_dotenv(settings)
    return settings
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def connect(args: argparse.Namespace) -> oracledb.Connection:
    """Open a standalone (non-pooled) connection from ``args``.

    Uses ``args.host``/``port``/``service`` as a ``host:port/service`` DSN
    plus ``args.user``/``password``. This module's cached read path does not
    call it; that path uses ``nfscache.connect_factory``.
    """
    return oracledb.connect(
        dsn=f"{args.host}:{args.port}/{args.service}",
        user=args.user,
        password=args.password,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _fetch_data_container(
    connection: oracledb.Connection,
    sql: str,
    *,
    batch_size: int,
) -> DataContainer:
    """Execute ``sql`` and collect the full result set into a DataContainer.

    Rows are fetched ``batch_size`` at a time (also used as
    ``cursor.arraysize``). Each batch becomes a Polars DataFrame, and the
    batches are concatenated at the end. Column names come from
    ``cursor.description``. An empty result yields an empty frame with those
    column names.
    """
    batches: list[pl.DataFrame] = []
    with connection.cursor() as cursor:
        cursor.arraysize = batch_size
        cursor.execute(sql)
        headers = tuple(column[0] for column in cursor.description)

        while True:
            rows = cursor.fetchmany(batch_size)
            if not rows:
                break
            # Infer dtypes from every row of the batch, not only the first
            # 100 (the default). Oracle NUMBER values can arrive as int in
            # some rows and float in others.
            batches.append(
                pl.DataFrame(
                    rows,
                    schema=headers,
                    orient="row",
                    infer_schema_length=None,
                )
            )

    if not batches:
        df = pl.DataFrame(schema=headers)
    else:
        # Each batch is typed independently, so a column can differ between
        # batches (Int64 vs Float64, Null vs String). "vertical_relaxed"
        # casts to the common supertype instead of raising a SchemaError.
        df = pl.concat(batches, how="vertical_relaxed")
    return DataContainer({"headers": headers, "data": df})
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Pool-backed factory: the version probe (every call) and the cold-load fetch
# both borrow from one process-local pool instead of opening a fresh connection.
# Built at import time from oracle_args() (defaults + .env); make_pool_factory
# defers creating the pool itself until the first connection is requested.
nfscache.connect_factory = make_pool_factory(oracle_args())
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@nfscache.sql
def read_data_container(sql: str) -> DataContainer:
    """Run ``sql`` against Oracle and return the rows as a DataContainer.

    Decorated with ``nfscache.sql``; this body is the cache-miss path. The
    connection comes from ``nfscache.connect_factory``. The fetch batch size
    comes from ``oracle_args()`` (defaults + ``.env``), which is re-read on
    every call.
    """
    # This body only runs on a cache miss, so reaching it means a live read.
    print(f"Serving from Oracle (cache miss): {sql}", flush=True)
    settings = oracle_args()
    with nfscache.connect_factory() as connection:
        container = _fetch_data_container(
            connection,
            sql,
            batch_size=int(settings.batch_size),
        )
    return container
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options for reading one SQL query through the cache.

    Precedence, highest first: flags given explicitly on the command line,
    then values from the nearest ``.env`` (via ``apply_dotenv``), then the
    built-in defaults below.

    Raises:
        ValueError: if a ``.env`` value cannot be converted (from apply_dotenv).
        SystemExit: on invalid command-line usage (argparse).
    """
    parser = argparse.ArgumentParser(
        description="Read Oracle SQL into a DataContainer."
    )
    parser.add_argument("sql")
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", type=int, default=1521)
    parser.add_argument("--service", default="FREEPDB1")
    parser.add_argument("--user", default="SOMEUSER")
    parser.add_argument("--password", default="cache")
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
    # Install .env values as parser defaults instead of overwriting the parsed
    # namespace afterwards. Otherwise a flag such as --host would be silently
    # replaced by ORACLE_HOST from .env.
    env_defaults = argparse.Namespace()
    apply_dotenv(env_defaults)
    parser.set_defaults(**vars(env_defaults))
    return parser.parse_args()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def main() -> int:
    """CLI entry point: read ``args.sql`` through the cache and print the table.

    Returns:
        0 on success.
    """
    args = parse_args()
    # The module-level factory was built at import time from oracle_args()
    # (defaults + .env). Without this line, --host/--port/--service/--user/
    # --password would be parsed and then ignored. make_pool_factory is lazy,
    # and pools are keyed per (pid, dsn, user), so identical settings reuse
    # the same pool.
    # NOTE(review): --batch-size is still not consulted; read_data_container
    # takes its batch size from oracle_args().
    nfscache.connect_factory = make_pool_factory(args)
    # Go through the cache so the run reports Oracle (miss) vs cache (hit).
    container = read_data_container(args.sql)
    print("table:", container.data.rows_data_pl)
    return 0
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import re
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import oracledb
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from nfscache.database.oracle_env import apply_dotenv
|
|
9
|
+
from nfscache.data.data_container import DataContainer
|
|
10
|
+
|
|
11
|
+
# Accepted unquoted Oracle identifier: an ASCII letter followed by up to 127
# ASCII letters, digits, "_", "$" or "#" (128 characters at most).
IDENTIFIER_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_$#]{0,127}$")

# Polars dtype classes grouped by the Oracle column type oracle_type() maps
# them to (NUMBER(38), BINARY_DOUBLE and VARCHAR2(4000) respectively).
INTEGER_TYPES = {
    pl.Int8,
    pl.Int16,
    pl.Int32,
    pl.Int64,
    pl.UInt8,
    pl.UInt16,
    pl.UInt32,
    pl.UInt64,
}
FLOAT_TYPES = {pl.Float32, pl.Float64}
STRING_TYPES = {pl.String, pl.Categorical, pl.Enum}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def oracle_identifier(name: str) -> str:
    """Validate ``name`` against ``IDENTIFIER_RE`` and return it upper-cased.

    Only ASCII letters, digits, ``_``, ``$`` and ``#`` pass, so the result can
    be spliced into SQL text without quoting. Reserved words are not rejected.

    Raises:
        ValueError: if ``name`` is not an acceptable identifier.
    """
    if IDENTIFIER_RE.fullmatch(name) is None:
        raise ValueError(f"Invalid Oracle identifier: {name!r}")
    return name.upper()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def table_name_from_path(path: Path) -> str:
    """Derive a validated table name from the file stem (``a/foo.parquet`` -> ``FOO``)."""
    stem = path.stem
    return oracle_identifier(stem)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def oracle_type(dtype: pl.DataType) -> str:
    """Map a Polars dtype to the Oracle column type used in ``create table``.

    Integers -> NUMBER(38); floats -> BINARY_DOUBLE; String/Categorical/Enum
    -> VARCHAR2(4000); Boolean -> NUMBER(1); Decimal -> NUMBER; Date -> DATE;
    Datetime -> TIMESTAMP; anything else -> CLOB.
    """
    # Compare on the bare dtype class. df.dtypes yields *instances*, and
    # parametrised ones such as Categorical(...) or Enum([...]) define their
    # own __hash__. A lookup like `dtype in {pl.Enum}` can therefore miss,
    # sending such columns to CLOB.
    base = dtype.base_type()
    if base in INTEGER_TYPES:
        return "NUMBER(38)"
    if base in FLOAT_TYPES:
        return "BINARY_DOUBLE"
    if base in STRING_TYPES:
        return "VARCHAR2(4000)"
    if base == pl.Boolean:
        return "NUMBER(1)"
    if base == pl.Decimal:
        # Unconstrained NUMBER keeps exact decimal values (previously CLOB).
        return "NUMBER"
    if base == pl.Date:
        return "DATE"
    if base == pl.Datetime:
        # NOTE(review): any time zone on the dtype is not carried into the column type.
        return "TIMESTAMP"
    return "CLOB"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def read_data_container(path: Path) -> DataContainer:
    """Load a parquet file into a DataContainer whose headers are its column names."""
    print(f"Reading: {path}...")
    frame = pl.read_parquet(path)
    column_names = tuple(frame.columns)
    return DataContainer({"headers": column_names, "data": frame})
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def connect(args: argparse.Namespace) -> oracledb.Connection:
    """Open a standalone connection from ``args``.

    Uses ``args.host``/``port``/``service`` as a ``host:port/service`` DSN
    plus ``args.user``/``password``.
    """
    return oracledb.connect(
        dsn=f"{args.host}:{args.port}/{args.service}",
        user=args.user,
        password=args.password,
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def current_scn(connection: oracledb.Connection) -> int:
    """Return ``current_scn`` from ``v$database`` as an int."""
    with connection.cursor() as cursor:
        row = cursor.execute("select current_scn from v$database").fetchone()
    (scn_value,) = row
    return int(scn_value)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def drop_table_if_exists(connection: oracledb.Connection, table_name: str) -> None:
    """Run ``drop table <table_name> purge``, tolerating a missing table.

    Errors with code 942 (ORA-00942, table or view does not exist) are
    swallowed. Any other ``oracledb.DatabaseError`` is re-raised.
    """
    with connection.cursor() as cursor:
        try:
            cursor.execute(f"drop table {table_name} purge")
        except oracledb.DatabaseError as exc:
            (error_obj,) = exc.args
            table_was_missing = getattr(error_obj, "code", None) == 942
            if not table_was_missing:
                raise
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def create_table(
    connection: oracledb.Connection,
    table_name: str,
    df: pl.DataFrame,
) -> list[str]:
    """Create ``table_name`` with one column per column of ``df``.

    Column names go through ``oracle_identifier`` (validated, upper-cased) and
    types through ``oracle_type``.

    Returns:
        The Oracle column names, in ``df`` column order.

    Raises:
        ValueError: if a column name is not a valid identifier.
    """
    column_names = list(map(oracle_identifier, df.columns))
    column_specs = ", ".join(
        f"{name} {oracle_type(dtype)}"
        for name, dtype in zip(column_names, df.dtypes, strict=True)
    )
    with connection.cursor() as cursor:
        cursor.execute(f"create table {table_name} ({column_specs})")
    return column_names
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def insert_data_container(
    connection: oracledb.Connection,
    table_name: str,
    data_container: DataContainer,
    *,
    batch_size: int,
) -> tuple[int, int]:
    """Create ``table_name`` from the container's frame and bulk-insert its rows.

    Rows are sent with ``executemany`` in slices of ``batch_size``, using
    positional binds (``:1, :2, ...``). A single commit follows the last slice.

    Returns:
        ``(rows_inserted, column_count)``.

    Raises:
        TypeError: if ``data_container.data.rows_data_pl`` is not a DataFrame.
    """
    frame = data_container.data.rows_data_pl
    if not isinstance(frame, pl.DataFrame):
        raise TypeError("DataContainer.data.rows_data_pl must be a Polars DataFrame")

    column_names = create_table(connection, table_name, frame)
    bind_list = ", ".join(
        f":{position}" for position, _ in enumerate(column_names, start=1)
    )
    insert_sql = (
        f"insert into {table_name} ({', '.join(column_names)}) "
        f"values ({bind_list})"
    )

    total_rows = 0
    with connection.cursor() as cursor:
        for chunk in frame.iter_slices(n_rows=batch_size):
            chunk_rows = chunk.rows()
            if chunk_rows:
                cursor.executemany(insert_sql, chunk_rows)
                total_rows += len(chunk_rows)
    connection.commit()
    return total_rows, len(column_names)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options for writing a parquet file to an Oracle table.

    Precedence, highest first: flags given explicitly on the command line,
    then values from the nearest ``.env`` (via ``apply_dotenv``), then the
    built-in defaults below.

    Raises:
        ValueError: if a ``.env`` value cannot be converted (from apply_dotenv).
        SystemExit: on invalid command-line usage (argparse).
    """
    parser = argparse.ArgumentParser(
        description="Read a parquet-backed DataContainer and write it to Oracle."
    )
    parser.add_argument("parquet_path", type=Path)
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", type=int, default=1521)
    parser.add_argument("--service", default="FREEPDB1")
    parser.add_argument("--user", default="SOMEUSER")
    parser.add_argument("--password", default="cache")
    parser.add_argument(
        "--table",
        help="Oracle table name. Defaults to the parquet file stem.",
    )
    parser.add_argument("--batch-size", type=int, default=1000)
    # Install .env values as parser defaults instead of overwriting the parsed
    # namespace afterwards. Otherwise an explicit flag such as --table would
    # be silently replaced by ORACLE_TABLE from .env.
    env_defaults = argparse.Namespace()
    apply_dotenv(env_defaults)
    parser.set_defaults(**vars(env_defaults))
    return parser.parse_args()
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def main() -> int:
    """Load a parquet file and (re)create it as an Oracle table.

    The target is ``--table`` when given, otherwise the parquet file stem.
    Either way it is validated by ``oracle_identifier``. Any existing table of
    that name is dropped first. The database SCN is printed before and after
    the write.

    Returns:
        0 on success.
    """
    args = parse_args()
    parquet_path = args.parquet_path
    if args.table is None:
        table_name = table_name_from_path(parquet_path)
    else:
        table_name = oracle_identifier(args.table)

    container = read_data_container(parquet_path)
    frame = container.data.rows_data_pl
    if not isinstance(frame, pl.DataFrame):
        raise TypeError("DataContainer.data.rows_data_pl must be a Polars DataFrame")

    print(f"DataContainer: rows={frame.height} cols={frame.width}")
    with connect(args) as connection:
        scn_before = current_scn(connection)
        print(f"Oracle current_scn before write: {scn_before}")
        drop_table_if_exists(connection, table_name)
        written_rows, written_cols = insert_data_container(
            connection,
            table_name,
            container,
            batch_size=int(args.batch_size),
        )
        scn_after = current_scn(connection)

    print(
        f"Wrote DataContainer to {args.user.upper()}.{table_name}: "
        f"rows={written_rows} cols={written_cols}"
    )
    print(f"Oracle current_scn after write: {scn_after}")
    return 0
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import re
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import oracledb
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from nfscache.database.oracle_env import apply_dotenv
|
|
9
|
+
from nfscache.data.data_container import DataContainer
|
|
10
|
+
from nfscache.util.generate_parquets import ensure_one_parquet
|
|
11
|
+
|
|
12
|
+
# Accepted unquoted Oracle identifier: an ASCII letter followed by up to 127
# ASCII letters, digits, "_", "$" or "#" (128 characters at most).
IDENTIFIER_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_$#]{0,127}$")

# Polars dtype classes grouped by the Oracle column type oracle_type() maps
# them to (NUMBER(38), BINARY_DOUBLE and VARCHAR2(4000) respectively).
INTEGER_TYPES = {
    pl.Int8,
    pl.Int16,
    pl.Int32,
    pl.Int64,
    pl.UInt8,
    pl.UInt16,
    pl.UInt32,
    pl.UInt64,
}
FLOAT_TYPES = {pl.Float32, pl.Float64}
STRING_TYPES = {pl.String, pl.Categorical, pl.Enum}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def oracle_identifier(name: str) -> str:
    """Validate ``name`` against ``IDENTIFIER_RE`` and return it upper-cased.

    Only ASCII letters, digits, ``_``, ``$`` and ``#`` pass, so the result can
    be spliced into SQL text without quoting. Reserved words are not rejected.

    Raises:
        ValueError: if ``name`` is not an acceptable identifier.
    """
    if IDENTIFIER_RE.fullmatch(name) is None:
        raise ValueError(f"Invalid Oracle identifier: {name!r}")
    return name.upper()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def oracle_type(dtype: pl.DataType) -> str:
    """Map a Polars dtype to the Oracle column type used in ``create table``.

    Integers -> NUMBER(38); floats -> BINARY_DOUBLE; String/Categorical/Enum
    -> VARCHAR2(4000); Boolean -> NUMBER(1); Decimal -> NUMBER; Date -> DATE;
    Datetime -> TIMESTAMP; anything else -> CLOB.
    """
    # Compare on the bare dtype class. df.dtypes yields *instances*, and
    # parametrised ones such as Categorical(...) or Enum([...]) define their
    # own __hash__. A lookup like `dtype in {pl.Enum}` can therefore miss,
    # sending such columns to CLOB.
    base = dtype.base_type()
    if base in INTEGER_TYPES:
        return "NUMBER(38)"
    if base in FLOAT_TYPES:
        return "BINARY_DOUBLE"
    if base in STRING_TYPES:
        return "VARCHAR2(4000)"
    if base == pl.Boolean:
        return "NUMBER(1)"
    if base == pl.Decimal:
        # Unconstrained NUMBER keeps exact decimal values (previously CLOB).
        return "NUMBER"
    if base == pl.Date:
        return "DATE"
    if base == pl.Datetime:
        # NOTE(review): any time zone on the dtype is not carried into the column type.
        return "TIMESTAMP"
    return "CLOB"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def connect(args: argparse.Namespace) -> oracledb.Connection:
    """Open a standalone connection from ``args``.

    Uses ``args.host``/``port``/``service`` as a ``host:port/service`` DSN
    plus ``args.user``/``password``.
    """
    return oracledb.connect(
        dsn=f"{args.host}:{args.port}/{args.service}",
        user=args.user,
        password=args.password,
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def make_data_container(args: argparse.Namespace) -> tuple[Path, DataContainer]:
    """Obtain a synthetic parquet file via ``ensure_one_parquet`` and load it.

    Generator options come from ``args``. ``seed`` is always ``None``.

    Returns:
        ``(parquet_path, container)``, with the container's headers set to the
        file's column names.
    """
    # NOTE(review): base_name encodes only the row count. If ensure_one_parquet
    # reuses an existing file by name, changing --cols / --n-int-cols /
    # --n-str-cols may not regenerate it. Confirm against generate_parquets.
    generator_options = {
        "out_dir": Path(args.data_dir),
        "base_name": f"ORACLE_TEST_{args.rows}.parquet",
        "prefix": "A_",
        "n_rows": int(args.rows),
        "n_cols": int(args.cols),
        "seed": None,
        "float_scale": float(args.float_scale),
        "n_int_cols": int(args.n_int_cols),
        "n_str_cols": int(args.n_str_cols),
    }
    parquet_path = ensure_one_parquet(**generator_options)
    frame = pl.read_parquet(parquet_path)
    container = DataContainer({"headers": tuple(frame.columns), "data": frame})
    return parquet_path, container
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def current_scn(connection: oracledb.Connection) -> int:
    """Return ``current_scn`` from ``v$database`` as an int."""
    with connection.cursor() as cursor:
        row = cursor.execute("select current_scn from v$database").fetchone()
    (scn_value,) = row
    return int(scn_value)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def drop_table_if_exists(connection: oracledb.Connection, table_name: str) -> None:
    """Run ``drop table <table_name> purge``, tolerating a missing table.

    Errors with code 942 (ORA-00942, table or view does not exist) are
    swallowed. Any other ``oracledb.DatabaseError`` is re-raised.
    """
    with connection.cursor() as cursor:
        try:
            cursor.execute(f"drop table {table_name} purge")
        except oracledb.DatabaseError as exc:
            (error_obj,) = exc.args
            table_was_missing = getattr(error_obj, "code", None) == 942
            if not table_was_missing:
                raise
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def create_table(
    connection: oracledb.Connection,
    table_name: str,
    df: pl.DataFrame,
) -> list[str]:
    """Create ``table_name`` with one column per column of ``df``.

    Column names go through ``oracle_identifier`` (validated, upper-cased) and
    types through ``oracle_type``.

    Returns:
        The Oracle column names, in ``df`` column order.

    Raises:
        ValueError: if a column name is not a valid identifier.
    """
    column_names = list(map(oracle_identifier, df.columns))
    column_specs = ", ".join(
        f"{name} {oracle_type(dtype)}"
        for name, dtype in zip(column_names, df.dtypes, strict=True)
    )
    with connection.cursor() as cursor:
        cursor.execute(f"create table {table_name} ({column_specs})")
    return column_names
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def insert_data_container(
    connection: oracledb.Connection,
    table_name: str,
    data_container: DataContainer,
    *,
    batch_size: int,
) -> tuple[int, int]:
    """Create ``table_name`` from the container's frame and bulk-insert its rows.

    Rows are sent with ``executemany`` in slices of ``batch_size``, using
    positional binds (``:1, :2, ...``). A single commit follows the last slice.

    Returns:
        ``(rows_inserted, column_count)``.

    Raises:
        TypeError: if ``data_container.data.rows_data_pl`` is not a DataFrame.
    """
    frame = data_container.data.rows_data_pl
    if not isinstance(frame, pl.DataFrame):
        raise TypeError("DataContainer.data.rows_data_pl must be a Polars DataFrame")

    column_names = create_table(connection, table_name, frame)
    bind_list = ", ".join(
        f":{position}" for position, _ in enumerate(column_names, start=1)
    )
    insert_sql = (
        f"insert into {table_name} ({', '.join(column_names)}) "
        f"values ({bind_list})"
    )

    total_rows = 0
    with connection.cursor() as cursor:
        for chunk in frame.iter_slices(n_rows=batch_size):
            chunk_rows = chunk.rows()
            if chunk_rows:
                cursor.executemany(insert_sql, chunk_rows)
                total_rows += len(chunk_rows)
    connection.commit()
    return total_rows, len(column_names)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options for generating test data and writing it to Oracle.

    Precedence, highest first: flags given explicitly on the command line,
    then values from the nearest ``.env`` (via ``apply_dotenv``), then the
    built-in defaults below.

    Raises:
        ValueError: if a ``.env`` value cannot be converted (from apply_dotenv).
        SystemExit: on invalid command-line usage (argparse).
    """
    parser = argparse.ArgumentParser(
        description="Generate a DataContainer and write it to the local Oracle DB."
    )
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", type=int, default=1521)
    parser.add_argument("--service", default="FREEPDB1")
    parser.add_argument("--user", default="SOMEUSER")
    parser.add_argument("--password", default="cache")
    parser.add_argument("--table", default="DATA_CONTAINER_DEMO")
    parser.add_argument("--rows", type=int, default=4096)
    parser.add_argument("--cols", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=1000)
    parser.add_argument("--data-dir", default="parquet")
    parser.add_argument("--float-scale", type=float, default=5.0)
    parser.add_argument("--n-int-cols", type=int, default=4)
    parser.add_argument("--n-str-cols", type=int, default=8)
    # Install .env values as parser defaults instead of overwriting the parsed
    # namespace afterwards. Otherwise an explicit flag such as --table would
    # be silently replaced by ORACLE_TABLE from .env.
    env_defaults = argparse.Namespace()
    apply_dotenv(env_defaults)
    parser.set_defaults(**vars(env_defaults))
    return parser.parse_args()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def main() -> int:
    """Generate a synthetic parquet file and (re)create it as an Oracle table.

    Any existing ``--table`` is dropped first. The database SCN is printed
    before and after the write.

    Returns:
        0 on success.
    """
    args = parse_args()
    table_name = oracle_identifier(args.table)
    parquet_path, container = make_data_container(args)
    frame = container.data.rows_data_pl
    if not isinstance(frame, pl.DataFrame):
        raise TypeError("DataContainer.data.rows_data_pl must be a Polars DataFrame")

    print(f"Generated: {parquet_path} rows={frame.height} cols={frame.width}")
    with connect(args) as connection:
        scn_before = current_scn(connection)
        print(f"Oracle current_scn before write: {scn_before}")
        drop_table_if_exists(connection, table_name)
        written_rows, written_cols = insert_data_container(
            connection,
            table_name,
            container,
            batch_size=int(args.batch_size),
        )
        scn_after = current_scn(connection)

    print(
        f"Wrote DataContainer to {args.user.upper()}.{table_name}: "
        f"rows={written_rows} cols={written_cols}"
    )
    print(f"Oracle current_scn after write: {scn_after}")
    return 0
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|