nanasqlite-1.3.3.dev4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nanasqlite/__init__.py +52 -0
- nanasqlite/async_core.py +1456 -0
- nanasqlite/cache.py +335 -0
- nanasqlite/core.py +2336 -0
- nanasqlite/exceptions.py +117 -0
- nanasqlite/py.typed +0 -0
- nanasqlite/sql_utils.py +174 -0
- nanasqlite/utils.py +202 -0
- nanasqlite-1.3.3.dev4.dist-info/METADATA +413 -0
- nanasqlite-1.3.3.dev4.dist-info/RECORD +13 -0
- nanasqlite-1.3.3.dev4.dist-info/WHEEL +5 -0
- nanasqlite-1.3.3.dev4.dist-info/licenses/LICENSE +21 -0
- nanasqlite-1.3.3.dev4.dist-info/top_level.txt +1 -0
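The API surface shown in core.py below is dict-like. A minimal usage sketch (illustrative only: the import path follows the module location in this wheel, and the file name and stored values are made up):

    from nanasqlite.core import NanaSQLite

    with NanaSQLite("example.db") as db:             # the context manager closes the connection on exit
        db["user:1"] = {"name": "Alice", "age": 25}  # written through to SQLite immediately
        print(db["user:1"])                          # lazy-loaded on first access, then served from memory
        db.batch_update({"a": 1, "b": 2, "c": 3})    # transactional bulk write via executemany
        print(len(db), sorted(db.keys()))            # counts and keys are always read from the DB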
nanasqlite/core.py
ADDED
@@ -0,0 +1,2336 @@
"""
NanaSQLite: APSW SQLite-backed dict wrapper with memory caching.

Wraps a regular Python dict and persists every operation to SQLite.
- Writes: persisted to SQLite immediately
- Reads: lazy-loaded by default (on first access); once read, kept in memory
- Bulk load: bulk_load=True loads all data into memory at startup
"""

from __future__ import annotations

import json
import logging
import os
import re
import threading
import warnings
import weakref
from collections.abc import Iterator, MutableMapping
from typing import Any, Literal

import apsw

try:
    from cryptography.fernet import Fernet
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
    HAS_CRYPTOGRAPHY = True
except ImportError:
    HAS_CRYPTOGRAPHY = False

from .cache import CacheStrategy, CacheType, create_cache
from .exceptions import (
    NanaSQLiteClosedError,
    NanaSQLiteConnectionError,
    NanaSQLiteDatabaseError,
    NanaSQLiteTransactionError,
    NanaSQLiteValidationError,
)
from .sql_utils import fast_validate_sql_chars, sanitize_sql_for_function_scan

# Regex pattern for validating identifiers (alphanumerics and underscores only, no leading digit)
IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

logger = logging.getLogger(__name__)

# Optional fast JSON (orjson)
try:
    import orjson  # type: ignore

    HAS_ORJSON = True
except Exception:
    HAS_ORJSON = False


class NanaSQLite(MutableMapping):
    """
    APSW SQLite-backed dict wrapper with security and connection-management enhancements (v1.2.0).

    Internally maintains a Python dict and synchronizes it with SQLite during operations.
    v1.2.0 introduces enhanced dynamic SQL validation, ReDoS protection, and strict connection management.

    Args:
        db_path: Path to the SQLite database file
        table: Table name to use (default: "data")
        bulk_load: If True, load all data into memory at initialization
        strict_sql_validation: If True, reject queries containing functions that are not allowed (v1.2.0)
        max_clause_length: Maximum length of SQL clauses (ReDoS protection, v1.2.0)
    """

    def __init__(
        self,
        db_path: str,
        table: str = "data",
        bulk_load: bool = False,
        optimize: bool = True,
        cache_size_mb: int = 64,
        strict_sql_validation: bool = True,
        allowed_sql_functions: list[str] | None = None,
        forbidden_sql_functions: list[str] | None = None,
        max_clause_length: int | None = 1000,
        cache_strategy: CacheType | Literal["unbounded", "lru", "ttl"] = CacheType.UNBOUNDED,
        cache_size: int | None = None,
        cache_ttl: float | None = None,
        cache_persistence_ttl: bool = False,
        encryption_key: str | bytes | None = None,
        encryption_mode: Literal["aes-gcm", "chacha20", "fernet"] = "aes-gcm",
        _shared_connection: apsw.Connection | None = None,
        _shared_lock: threading.RLock | None = None,
    ):
        """
        Args:
            db_path: Path to the SQLite database file
            table: Table name to use (default: "data")
            bulk_load: If True, load all data into memory at initialization
            optimize: If True, apply speed settings such as WAL mode
            cache_size_mb: SQLite cache size in MB (default: 64 MB)
            strict_sql_validation: If True, reject queries containing functions that are not allowed
            allowed_sql_functions: Additional SQL functions to allow
            forbidden_sql_functions: SQL functions to explicitly forbid
            max_clause_length: Maximum length of SQL clauses (ReDoS protection); None disables the limit
            _shared_connection: Internal: shared connection (used by the table() method)
            _shared_lock: Internal: shared lock (used by the table() method)
        """
        self._db_path: str = db_path
        self._table: str = table

        # Encryption setup
        self._encryption_key = encryption_key
        self._encryption_mode = encryption_mode
        self._fernet: Fernet | None = None
        self._aead: AESGCM | ChaCha20Poly1305 | None = None

        if encryption_key:
            if not HAS_CRYPTOGRAPHY:
                raise ImportError(
                    "Encryption requires the 'cryptography' library. "
                    "Install it with: pip install nanasqlite[encryption]"
                )
            # Support both str (base64) and bytes
            key_bytes: bytes = encryption_key.encode("utf-8") if isinstance(encryption_key, str) else encryption_key

            if encryption_mode == "fernet":
                self._fernet = Fernet(key_bytes)
            elif encryption_mode == "aes-gcm":
                self._aead = AESGCM(key_bytes)
            elif encryption_mode == "chacha20":
                self._aead = ChaCha20Poly1305(key_bytes)
            else:
                raise ValueError(f"Unsupported encryption_mode: {encryption_mode}")

        # Setup Persistence TTL callback if enabled
        on_expire = None
        if (cache_strategy == CacheType.TTL or cache_strategy == "ttl") and cache_persistence_ttl:

            def _expire_callback(key: str, value: Any) -> None:
                try:
                    # Use a new or shared connection to delete from DB
                    # Implementation detail: we need to be careful with locks
                    self._delete_from_db_on_expire(key)
                except Exception as e:
                    logger.error(f"Failed to delete expired key '{key}' from DB: {e}")

            on_expire = _expire_callback

        self._cache: CacheStrategy = create_cache(cache_strategy, cache_size, ttl=cache_ttl, on_expire=on_expire)
        self._data = self._cache.get_data()
        self._lru_mode = (
            (cache_strategy == CacheType.LRU) or (cache_strategy == "lru") or
            (cache_strategy == CacheType.TTL) or (cache_strategy == "ttl")
        )

        if not self._lru_mode:
            # Modes other than Unbounded may not use the internal dict directly,
            # but the current design checks existence through _cached_keys
            self._cached_keys = self._cache._cached_keys  # type: ignore
        else:
            # In LRU/TTL mode, holding the data is itself the proof of existence
            self._cached_keys = self._data  # type: ignore

        self._all_loaded: bool = False  # flag: all data has been loaded

        # Security settings
        self.strict_sql_validation = strict_sql_validation
        self.allowed_sql_functions = set(allowed_sql_functions or [])
        self.forbidden_sql_functions = set(forbidden_sql_functions or [])
        self.max_clause_length = max_clause_length

        # SQL functions allowed by default
        self._default_allowed_functions = {
            "COUNT",
            "SUM",
            "AVG",
            "MIN",
            "MAX",
            "ABS",
            "UPPER",
            "LOWER",
            "LENGTH",
            "ROUND",
            "COALESCE",
            "IFNULL",
            "NULLIF",
            "STRFTIME",
            "DATE",
            "TIME",
            "DATETIME",
            "JULIANDAY",
        }

        # Transaction state management
        self._in_transaction: bool = False  # whether a transaction is in progress
        self._transaction_depth: int = 0  # nesting level (for warnings)

        # Track child instances (for resource management)
        self._child_instances = weakref.WeakSet()  # weak references via WeakSet (dead references are cleaned up automatically)
        self._is_closed: bool = False  # whether this connection has been closed
        self._parent_closed: bool = False  # whether the parent connection has been closed

        # Share an existing connection and lock, or create new ones
        if _shared_connection is not None:
            # Share the connection (when called from the table() method)
            self._connection: apsw.Connection = _shared_connection
            self._lock = _shared_lock if _shared_lock is not None else threading.RLock()
            self._is_connection_owner = False  # not the owner of the connection
        else:
            # Create a new connection (normal initialization)
            try:
                self._connection: apsw.Connection = apsw.Connection(db_path)
            except apsw.Error as e:
                raise NanaSQLiteConnectionError(f"Failed to connect to database: {e}") from e
            self._lock = threading.RLock()
            self._is_connection_owner = True  # owner of the connection

            # Speed settings (connection owner only)
            if optimize:
                self._apply_optimizations(cache_size_mb)

        # Create the table
        with self._lock:
            self._connection.execute(f"""
                CREATE TABLE IF NOT EXISTS {self._table} (
                    key TEXT PRIMARY KEY,
                    value TEXT
                )
            """)

        # Bulk load
        if bulk_load:
            self.load_all()
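
    # Construction sketch (illustrative, not exhaustive): cache behaviour and
    # encryption are selected entirely through the keyword arguments documented above.
    #
    #   NanaSQLite("app.db", bulk_load=True)                          # load everything up front
    #   NanaSQLite("app.db", cache_strategy="lru", cache_size=1000)   # bounded LRU cache
    #   NanaSQLite("app.db", cache_strategy="ttl", cache_ttl=60.0,
    #              cache_persistence_ttl=True)                        # expired keys are also deleted from the DB
    #   NanaSQLite("secret.db", encryption_key=key, encryption_mode="aes-gcm")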

    def _apply_optimizations(self, cache_size_mb: int = 64) -> None:
        """
        Apply APSW speed settings

        - WAL mode: better write concurrency; improves writes from 30ms+ to under 1ms
        - synchronous=NORMAL: faster while staying safe
        - mmap: memory-mapped I/O for faster reads
        - cache_size: larger SQLite in-memory cache
        - temp_store=MEMORY: keep temporary tables in memory
        """
        cursor = self._connection.cursor()

        # WAL (Write-Ahead Logging) mode - the core of the write speedup
        cursor.execute("PRAGMA journal_mode = WAL")

        # synchronous=NORMAL: safe and fast under WAL
        cursor.execute("PRAGMA synchronous = NORMAL")

        # Memory-mapped I/O (256MB) - faster reads
        cursor.execute("PRAGMA mmap_size = 268435456")

        # Cache size (negative value = units of KB)
        cache_kb = cache_size_mb * 1024
        cursor.execute(f"PRAGMA cache_size = -{cache_kb}")

        # Keep temporary tables in memory
        cursor.execute("PRAGMA temp_store = MEMORY")

        # Page size optimization (only effective for new databases)
        cursor.execute("PRAGMA page_size = 4096")

    @staticmethod
    def _sanitize_identifier(identifier: str) -> str:
        """
        Validate an SQLite identifier (table name, column name, etc.)

        Args:
            identifier: Identifier to validate

        Returns:
            Validated identifier (wrapped in double quotes)

        Raises:
            NanaSQLiteValidationError: If the identifier is invalid

        Note:
            SQLite identifiers support:
            - Alphanumerics and underscores
            - No leading digit
            - SQL keywords, as long as they are quoted
        """
        if not identifier:
            raise NanaSQLiteValidationError("Identifier cannot be empty")

        # Basic validation: allow only alphanumerics and underscores
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier):
            raise NanaSQLiteValidationError(
                f"Invalid identifier '{identifier}': must start with letter or underscore "
                "and contain only alphanumeric characters and underscores"
            )

        # In SQLite, wrapping an identifier in double quotes escapes it
        return f'"{identifier}"'

    # ==================== Private Methods ====================

    def __hash__(self):
        # MutableMapping inhibits hashing by default because it's mutable.
        # However, we need identity-based hashing to track instances in WeakSet.
        # This is safe as long as we don't rely on content-based hashing in sets.
        # NOTE: This technically violates the rule that a==b implies hash(a)==hash(b),
        # because __eq__ implements content equivalence while __hash__ implements identity.
        # This is an intentional design choice to support WeakSet management while providing
        # convenient dict-like equality comparisons.
        return id(self)

    def __eq__(self, other):
        """
        Implement dict-like equality comparison

        Comparisons against other mappings (dict or MutableMapping) are content-based;
        anything else is compared by identity (is).

        Args:
            other: Object to compare against

        Returns:
            bool: True if equal, False otherwise

        Raises:
            NanaSQLiteClosedError: If the connection is closed
        """
        if isinstance(other, (dict, MutableMapping)):
            # Ensure the connection is open; propagate NanaSQLiteClosedError if not.
            self._check_connection()
            return dict(self.items()) == dict(other.items())
        return self is other

    def _check_connection(self) -> None:
        """
        Check that the connection is usable

        Raises:
            NanaSQLiteClosedError: If this connection, or its parent, has been closed
        """
        if self._is_closed:
            raise NanaSQLiteClosedError(f"Database connection is closed (table: '{self._table}').")
        if self._parent_closed:
            raise NanaSQLiteClosedError(
                f"Parent database connection is closed (table: '{self._table}'). "
                "If you obtained this instance via .table(), ensure the primary "
                "NanaSQLite instance remains open during usage."
            )

    def _validate_expression(
        self,
        expr: str | None,
        strict: bool | None = None,
        allowed: list[str] | None = None,
        forbidden: list[str] | None = None,
        override_allowed: bool = False,
        context: Literal["order_by", "group_by", "where", "column"] | None = None,
    ) -> None:
        """
        Validate an SQL expression (ORDER BY, GROUP BY, column names, etc.).

        Args:
            expr: SQL expression to validate
            strict: Hard-fail mode. If None, the instance setting is used.
            allowed: Functions additionally allowed for this query.
            forbidden: Functions explicitly forbidden for this query.
            override_allowed: If True, ignore the instance allow-list and use only this call's allowed list.
            context: Context for error messages ("order_by", "group_by", "where", "column")

        Raises:
            NanaSQLiteValidationError: If strict=True and the expression is invalid
            UserWarning: If strict=False and the expression is invalid (execution is allowed)
        """
        if not expr:
            return

        # 0. legacy check for SQL injection patterns
        # test_security.py compatibility: raise ValueError for strictly dangerous patterns
        # We use a combined message to satisfy both test_security.py ("Potentially dangerous...")
        # and test_security_additions.py ("Invalid...")
        warning_text = "Potentially dangerous SQL pattern"

        context_labels = {
            "order_by": "order_by clause",
            "group_by": "group_by clause",
            "where": "where clause",
            "column": "column name",
        }
        label = context_labels.get(context)

        # Standardize format: "Invalid [label]: [warning_text]" (or "Invalid: [warning_text]" if no label)
        # This satisfies both legacy and new security tests.
        if label:
            full_msg = f"Invalid {label}: {warning_text}"
        else:
            full_msg = f"Invalid: {warning_text}"

        dangerous_patterns = [
            (r";", full_msg),
            (r"--", full_msg),
            (r"/\*", full_msg),
            (r"\b(DROP|DELETE|UPDATE|INSERT|TRUNCATE|ALTER)\b", full_msg),
        ]

        # 0.5. Fast character-set validation (ReDoS countermeasure)
        if not fast_validate_sql_chars(str(expr)):
            # If invalid characters are found, we apply strict or warning
            # Note: This is a preventative layer.
            # We use full_msg to maintain compatibility with existing tests expecting "Invalid [label]: ..."
            msg = f"{full_msg} or invalid characters detected."
            if strict or (strict is None and self.strict_sql_validation):
                raise ValueError(msg)
            else:
                warnings.warn(msg, UserWarning, stacklevel=2)

        for pattern, msg in dangerous_patterns:
            if re.search(pattern, str(expr), re.IGNORECASE):
                # Block highly dangerous patterns in strict mode, but only warn in non-strict
                if strict or (strict is None and self.strict_sql_validation):
                    raise ValueError(msg)
                else:
                    warnings.warn(msg, UserWarning, stacklevel=2)

        # 1. Length limit (ReDoS countermeasure)
        max_len = self.max_clause_length
        if max_len and len(expr) > max_len:
            msg = f"SQL expression exceeds maximum length of {max_len} characters."
            if strict or (strict is None and self.strict_sql_validation):
                raise NanaSQLiteValidationError(msg)
            else:
                warnings.warn(msg, UserWarning, stacklevel=2)

        # 2. Build the forbidden set (per-call arguments take precedence over the instance setting)
        forbidden_list = set(forbidden) if forbidden is not None else self.forbidden_sql_functions
        if "*" in forbidden_list:
            msg = "All SQL functions are forbidden for this expression."
            if strict or (strict is None and self.strict_sql_validation):
                raise NanaSQLiteValidationError(msg)
            else:
                warnings.warn(msg, UserWarning, stacklevel=2)
            return

        # 3. Build the allowed set
        effective_allowed = set()
        if not override_allowed:
            effective_allowed.update(self._default_allowed_functions)
            effective_allowed.update(self.allowed_sql_functions)

        if allowed:
            effective_allowed.update(allowed)

        # Anything on the forbidden list is removed from the allowed set
        effective_allowed -= forbidden_list

        # 4. Extract function calls
        # Mask string literals and comments before scanning for function calls;
        # this prevents false positives on patterns such as SELECT 'COUNT(' ...
        sanitized_expr = sanitize_sql_for_function_scan(expr)
        matches = re.findall(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", sanitized_expr)

        for func in matches:
            func_upper = func.upper()

            # Explicitly forbidden
            if func_upper in forbidden_list:
                msg = f"SQL function '{func_upper}' is explicitly forbidden."
                if strict or (strict is None and self.strict_sql_validation):
                    raise NanaSQLiteValidationError(msg)
                else:
                    warnings.warn(msg, UserWarning, stacklevel=2)
                continue

            # Not on the allowed list
            if func_upper not in effective_allowed:
                msg = (
                    f"SQL function '{func_upper}' is not in the allowed list. "
                    "Use 'allowed_sql_functions' to permit it if you trust this function."
                )
                if strict or (strict is None and self.strict_sql_validation):
                    raise NanaSQLiteValidationError(msg)
                else:
                    warnings.warn(msg, UserWarning, stacklevel=2)
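
    # Validation sketch (illustrative, assuming default settings): with
    # strict_sql_validation=True offending expressions raise, with strict disabled
    # they only emit a UserWarning.
    #
    #   db.query(order_by="LENGTH(value) DESC")      # LENGTH is on the default allow-list
    #   db.query(order_by="value; DROP TABLE data")  # ";" -> raises ValueError
    #   db.query(order_by="RANDOM()")                # not allowed -> NanaSQLiteValidationError
    #   db.query(order_by="RANDOM()", allowed_sql_functions=["RANDOM"])  # permitted per call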

    def _mark_parent_closed(self) -> None:
        """
        Called by the parent instance to record that the parent has been closed
        """
        self._parent_closed = True

    def _serialize(self, value: Any) -> bytes | str:
        """Serialize (JSON -> encryption if enabled)"""
        # Use fastest available JSON serializer
        if HAS_ORJSON:
            # orjson returns bytes
            data = orjson.dumps(value)
            json_str = None
        else:
            json_str = json.dumps(value, ensure_ascii=False)
            data = json_str.encode("utf-8")

        if self._fernet:
            return self._fernet.encrypt(data)

        if self._aead:
            # Generate 12 bytes nonce
            nonce = os.urandom(12)
            ciphertext = self._aead.encrypt(nonce, data, None)
            # Combine nonce + ciphertext
            return nonce + ciphertext

        # No encryption: store as TEXT for compatibility/perf (str)
        if HAS_ORJSON:
            # Decode once to keep DB storage as TEXT
            return data.decode("utf-8")
        return json_str

    def _deserialize(self, value: bytes | str) -> Any:
        """Deserialize (decryption if enabled -> JSON)"""
        if self._fernet:
            decoded = self._fernet.decrypt(value).decode("utf-8")
            if HAS_ORJSON:
                return orjson.loads(decoded)
            return json.loads(decoded)

        if self._aead:
            if not isinstance(value, bytes):
                # Fallback or manual check if stored as string accidentally
                if HAS_ORJSON:
                    return orjson.loads(value)
                return json.loads(value)

            # Split nonce (12B) and ciphertext
            nonce = value[:12]
            ciphertext = value[12:]
            decoded = self._aead.decrypt(nonce, ciphertext, None).decode("utf-8")
            if HAS_ORJSON:
                return orjson.loads(decoded)
            return json.loads(decoded)

        # No encryption path
        if HAS_ORJSON:
            return orjson.loads(value)
        return json.loads(value)
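
    # Key-material sketch (illustrative): any key accepted by the corresponding
    # `cryptography` primitive can be passed as `encryption_key`.
    #
    #   from cryptography.fernet import Fernet
    #   from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
    #
    #   Fernet.generate_key()                  # encryption_mode="fernet"
    #   AESGCM.generate_key(bit_length=256)    # encryption_mode="aes-gcm"
    #   ChaCha20Poly1305.generate_key()        # encryption_mode="chacha20"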

    def _write_to_db(self, key: str, value: Any) -> None:
        """Immediate write: store the value in SQLite"""
        serialized = self._serialize(value)
        with self._lock:
            self._connection.execute(
                f"INSERT OR REPLACE INTO {self._table} (key, value) VALUES (?, ?)",  # nosec
                (key, serialized),
            )

    def _read_from_db(self, key: str) -> Any | None:
        """Read a value from SQLite"""
        with self._lock:
            cursor = self._connection.execute(
                f"SELECT value FROM {self._table} WHERE key = ?",  # nosec
                (key,),
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return self._deserialize(row[0])

    def _delete_from_db(self, key: str) -> None:
        """Delete a value from SQLite"""
        with self._lock:
            self._connection.execute(
                f"DELETE FROM {self._table} WHERE key = ?",  # nosec
                (key,),
            )

    def _get_all_keys_from_db(self) -> list:
        """Get all keys from SQLite"""
        with self._lock:
            cursor = self._connection.execute(
                f"SELECT key FROM {self._table}"  # nosec
            )
            return [row[0] for row in cursor]

    def _ensure_cached(self, key: str) -> bool:
        """
        If the key is not cached, load it from the DB (lazy load)
        Returns: whether the key exists
        """
        # FAST PATH for default Unbounded mode
        if not self._lru_mode:
            if key in self._cached_keys:
                return key in self._data
        else:
            if key in self._data:
                return True

        # Load from the DB
        value = self._read_from_db(key)

        if value is not None:
            if self._lru_mode or (hasattr(self._cache, "_max_size") and self._cache._max_size):
                self._cache.set(key, value)
            else:
                self._data[key] = value
                self._cached_keys.add(key)
            return True

        # Value is None (not in DB)
        if not self._lru_mode:
            self._cached_keys.add(key)
        return False

    # ==================== Dict Interface ====================

    def __getitem__(self, key: str) -> Any:
        """dict[key] - lazy-load, then return from memory"""
        if self._ensure_cached(key):
            # LRU updates order even on __getitem__
            if self._lru_mode:
                return self._cache.get(key)
            return self._data[key]
        raise KeyError(key)

    def __setitem__(self, key: str, value: Any) -> None:
        """dict[key] = value - immediate write + memory update"""
        self._check_connection()
        # Update memory
        if self._lru_mode or (hasattr(self._cache, "_max_size") and self._cache._max_size):
            self._cache.set(key, value)
        else:
            self._data[key] = value
            self._cached_keys.add(key)
        # Immediate write
        self._write_to_db(key, value)

    def __delitem__(self, key: str) -> None:
        """del dict[key] - immediate delete"""
        if not self._ensure_cached(key):
            raise KeyError(key)
        # Remove from memory
        if self._lru_mode:
            self._cache.delete(key)
        else:
            self._data.pop(key, None)
            self._cached_keys.discard(key)

        # Delete from the DB
        self._delete_from_db(key)

    def __contains__(self, key: str) -> bool:
        """
        key in dict - check whether the key exists

        O(1) when the key is cached; otherwise a lightweight existence query is used.
        Fast for pure existence checks because the full value is not loaded.
        """
        # FAST PATH
        if key in self._cached_keys:
            return key in self._data

        # Lightweight existence query (does not load the value)
        with self._lock:
            cursor = self._connection.execute(
                f"SELECT 1 FROM {self._table} WHERE key = ? LIMIT 1",  # nosec
                (key,),  # nosec
            )
            exists = cursor.fetchone() is not None

        if exists:
            # Mark the key as present but do not load the value (lazy load on next access)
            if self._lru_mode:
                self._cache.mark_cached(key)
            else:
                self._cached_keys.add(key)
            return True
        else:
            # Cache the fact that the key does not exist, too
            if not self._lru_mode:
                self._cached_keys.add(key)
            return False

    def __len__(self) -> int:
        """len(dict) - returns the actual row count in the DB"""
        with self._lock:
            cursor = self._connection.execute(
                f"SELECT COUNT(*) FROM {self._table}"  # nosec
            )
            return cursor.fetchone()[0]

    def __iter__(self) -> Iterator[str]:
        """for key in dict"""
        return iter(self.keys())

    def __repr__(self) -> str:
        return f"NanaSQLite({self._db_path!r}, table={self._table!r}, cached={self._cache.size})"

    # ==================== Dict Methods ====================

    def keys(self) -> list:
        """Get all keys (from the DB)"""
        return self._get_all_keys_from_db()

    def values(self) -> list:
        """Get all values (bulk-load, then return from memory)"""
        self._check_connection()
        self.load_all()
        return list(self._cache.get_data().values())

    def items(self) -> list:
        """Get all items (bulk-load, then return from memory)"""
        self.load_all()
        return list(self._cache.get_data().items())

    def get(self, key: str, default: Any = None) -> Any:
        """dict.get(key, default)"""
        if self._ensure_cached(key):
            if self._lru_mode:
                return self._cache.get(key)
            return self._data[key]
        return default

    def get_fresh(self, key: str, default: Any = None) -> Any:
        """
        Read directly from the DB, update the cache, and return the value

        Bypasses the cache and fetches the latest value from the DB.
        Useful after modifying the DB directly with execute(), for example.

        Has more overhead than a normal get(), so it is recommended only when
        the cache and the DB may be out of sync.

        Args:
            key: Key to fetch
            default: Default value if the key does not exist

        Returns:
            The latest value from the DB (or default if the key does not exist)

        Example:
            >>> db.execute("UPDATE data SET value = ? WHERE key = ?", ('"new"', "key"))
            >>> value = db.get_fresh("key")  # fetch the latest value from the DB
        """
        # Read directly from the DB
        value = self._read_from_db(key)

        if value is not None:
            # Update the cache
            if self._lru_mode:
                self._cache.set(key, value)
            else:
                self._data[key] = value
                self._cached_keys.add(key)
            return value
        else:
            # If the key does not exist, remove it from the cache as well
            if self._lru_mode:
                self._cache.delete(key)
            else:
                self._data.pop(key, None)
                self._cached_keys.add(key)  # mark it as "does not exist"
            return default

    def batch_get(self, keys: list[str]) -> dict[str, Any]:
        """
        Fetch multiple keys at once (efficient bulk load)

        Fetches several keys from the DB with a single `SELECT ... IN (...)` query.
        Fetched values are stored in the cache automatically.

        Args:
            keys: Keys to fetch

        Returns:
            A dict of the keys and values that were found

        Example:
            >>> results = db.batch_get(["user1", "user2", "user3"])
            >>> print(results)  # {"user1": {...}, "user2": {...}}
        """
        if not keys:
            return {}

        results = {}
        missing_keys = []

        # 1. Check what can be served from the cache
        for key in keys:
            if self._cache.is_cached(key):
                val = self._cache.get(key)
                if val is not None:
                    results[key] = val
            else:
                missing_keys.append(key)

        if not missing_keys:
            return results

        # 2. Fetch the remainder from the DB in a single query
        placeholders = ",".join(["?"] * len(missing_keys))
        sql = f"SELECT key, value FROM {self._table} WHERE key IN ({placeholders})"  # nosec

        with self._lock:
            cursor = self._connection.execute(sql, tuple(missing_keys))
            for key, val_str in cursor:
                value = self._deserialize(val_str)
                self._cache.set(key, value)
                results[key] = value

        # 3. Cache keys that were not in the DB either as "not present"
        found_keys = set(results.keys())
        for key in missing_keys:
            if key not in found_keys:
                self._cache.mark_cached(key)

        return results

    def pop(self, key: str, *args) -> Any:
        """dict.pop(key[, default])"""
        self._check_connection()
        if self._ensure_cached(key):
            value = self._cache.get(key)
            self._cache.delete(key)
            self._delete_from_db(key)
            return value
        if args:
            return args[0]
        raise KeyError(key)

    def update(self, mapping: dict = None, **kwargs) -> None:
        """dict.update(mapping) - bulk update"""
        if mapping:
            for key, value in mapping.items():
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def clear(self) -> None:
        """dict.clear() - delete everything"""
        self._cache.clear()
        self._all_loaded = False
        with self._lock:
            self._connection.execute(f"DELETE FROM {self._table}")  # nosec

    def setdefault(self, key: str, default: Any = None) -> Any:
        """dict.setdefault(key, default)"""
        if self._ensure_cached(key):
            return self._cache.get(key)
        self[key] = default
        return default

    # ==================== Special Methods ====================

    def load_all(self) -> None:
        """Bulk read: load all data into memory"""
        if self._all_loaded:
            return

        with self._lock:
            cursor = self._connection.execute(
                f"SELECT key, value FROM {self._table}"  # nosec
            )
            rows = list(cursor)  # fetch while holding the lock

        for key, value in rows:
            self._cache.set(key, self._deserialize(value))

        self._all_loaded = True

    def refresh(self, key: str = None) -> None:
        """
        Refresh the cache (reload from the DB)

        Args:
            key: Refresh only this key. If None, clear the whole cache and reload
        """
        if key is not None:
            # FAST PATH for performance
            if not self._lru_mode:
                self._data.pop(key, None)
                self._cached_keys.discard(key)
            else:
                self._cache.invalidate(key)
            self._ensure_cached(key)
        else:
            self.clear_cache()

    def is_cached(self, key: str) -> bool:
        """Whether the key is already cached"""
        # FAST PATH for performance
        if not self._lru_mode:
            return key in self._cached_keys
        return self._cache.is_cached(key)

    def batch_update(self, mapping: dict[str, Any]) -> None:
        """
        Bulk write (very fast, using a transaction + executemany)

        When writing a large amount of data at once this is 10-100x faster than a plain update.
        Optimized with executemany in v1.0.3rc5.

        Args:
            mapping: Keys and values to write

        Returns:
            None

        Example:
            >>> db.batch_update({"key1": "value1", "key2": "value2", ...})
        """
        if not mapping:
            return  # nothing to do for an empty mapping

        cursor = self._connection.cursor()
        cursor.execute("BEGIN IMMEDIATE")
        try:
            # Serialize up front and build the tuple list for executemany
            params = [(key, self._serialize(value)) for key, value in mapping.items()]
            cursor.executemany(
                f"INSERT OR REPLACE INTO {self._table} (key, value) VALUES (?, ?)",  # nosec
                params,
            )
            # Update the cache
            for key, value in mapping.items():
                if self._lru_mode:
                    self._cache.set(key, value)
                else:
                    self._data[key] = value
                    self._cached_keys.add(key)
            cursor.execute("COMMIT")
        except Exception:
            cursor.execute("ROLLBACK")
            raise

    def batch_delete(self, keys: list[str]) -> None:
        """
        Bulk delete (fast, using a transaction + executemany)

        Optimized with executemany in v1.0.3rc5.

        Args:
            keys: Keys to delete

        Returns:
            None
        """
        self._check_connection()
        if not keys:
            return  # nothing to do for an empty list

        cursor = self._connection.cursor()
        cursor.execute("BEGIN IMMEDIATE")
        try:
            # Build the tuple list for executemany
            params = [(key,) for key in keys]
            cursor.executemany(
                f"DELETE FROM {self._table} WHERE key = ?",  # nosec
                params,
            )
            # Update the cache
            for key in keys:
                if self._lru_mode:
                    self._cache.delete(key)
                else:
                    self._data.pop(key, None)
                    self._cached_keys.discard(key)
            cursor.execute("COMMIT")
        except Exception:
            cursor.execute("ROLLBACK")
            raise

    def to_dict(self) -> dict:
        """Get all data as a plain Python dict"""
        self._check_connection()
        self.load_all()
        return dict(self._data)

    def copy(self) -> dict:
        """Create a shallow copy (returns a standard dict)"""
        return self.to_dict()

    def clear_cache(self) -> None:
        """
        Clear the in-memory cache

        Does not delete data in the DB; only the in-memory cache is discarded.
        """
        self._cache.clear()
        self._all_loaded = False

    def _delete_from_db_on_expire(self, key: str) -> None:
        """Delete a row from the DB when it expires (internal use)"""
        with self._lock:
            if self._is_closed:
                return
            try:
                self._connection.execute(f'DELETE FROM "{self._table}" WHERE key = ?', (key,))  # nosec
            except apsw.Error as e:
                logger.error(f"SQL error during background expiration for key '{key}': {e}")

    def close(self) -> None:
        """
        Close the database connection

        Note: instances created via the table() method share the connection, so only
        the connection owner (the instance created first) actually closes it.

        Raises:
            NanaSQLiteTransactionError: If close is attempted while a transaction is in progress
        """
        if self._is_closed:
            return  # already closed; nothing to do

        if self._in_transaction:
            raise NanaSQLiteTransactionError(
                "Cannot close connection while transaction is in progress. Please commit or rollback first."
            )

        # Notify child instances
        for child in self._child_instances:
            child._mark_parent_closed()

        self._child_instances.clear()
        self._is_closed = True

        if self._is_connection_owner:
            try:
                self._connection.close()
            except apsw.Error as e:
                # A failed close is only reported as a warning
                import warnings

                warnings.warn(f"Failed to close database connection: {e}", stacklevel=2)

    def __enter__(self):
        """Context manager support"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager support"""
        self.close()
        return False

    # ==================== Pydantic Support ====================

    def set_model(self, key: str, model: Any) -> None:
        """
        Store a Pydantic model

        Serializes and stores a Pydantic model (a class derived from BaseModel).
        Converts it to a dict with model_dump() and also records the model's class information.

        Args:
            key: Key to store under
            model: Pydantic model instance

        Example:
            >>> from pydantic import BaseModel
            >>> class User(BaseModel):
            ...     name: str
            ...     age: int
            >>> user = User(name="Nana", age=20)
            >>> db.set_model("user", user)
        """
        try:
            # Check that this is a Pydantic model (presence of the model_dump method)
            if hasattr(model, "model_dump"):
                data = {
                    "__pydantic_model__": f"{model.__class__.__module__}.{model.__class__.__qualname__}",
                    "__pydantic_data__": model.model_dump(),
                }
                self[key] = data
            else:
                raise TypeError(f"Object of type {type(model)} is not a Pydantic model")
        except Exception as e:
            raise TypeError(f"Failed to serialize Pydantic model: {e}")

    def get_model(self, key: str, model_class: type = None) -> Any:
        """
        Retrieve a Pydantic model

        Deserializes and restores a stored Pydantic model.
        The stored class information is kept alongside the data, but model_class must
        currently be passed explicitly to rebuild the instance.

        Args:
            key: Key to fetch
            model_class: Pydantic model class (required; automatic detection is not performed)

        Returns:
            A Pydantic model instance

        Example:
            >>> user = db.get_model("user", User)
            >>> print(user.name)  # "Nana"
        """
        data = self[key]

        if isinstance(data, dict) and "__pydantic_model__" in data and "__pydantic_data__" in data:
            if model_class is None:
                # Automatic detection is complex, so passing model_class is required
                raise ValueError("model_class must be provided for get_model()")

            # Restore as a Pydantic model
            try:
                return model_class(**data["__pydantic_data__"])
            except Exception as e:
                raise ValueError(f"Failed to deserialize Pydantic model: {e}")
        elif model_class is not None:
            # Convert a plain dict into a Pydantic model
            try:
                return model_class(**data)
            except Exception as e:
                raise ValueError(f"Failed to create Pydantic model from data: {e}")
        else:
            raise ValueError("Data is not a Pydantic model and no model_class provided")

    # ==================== Direct SQL Execution ====================

    def execute(self, sql: str, parameters: tuple | None = None) -> apsw.Cursor:
        """
        Execute SQL directly

        Any SQL statement can be executed: SELECT, INSERT, UPDATE, DELETE, and so on.
        Parameter binding is supported (protection against SQL injection).

        .. warning::
            If you modify the default table (data) directly with this method, the
            internal cache (_data) may become inconsistent with the database.
            Call `refresh()` to update the cache.

        Args:
            sql: SQL statement to execute
            parameters: SQL parameters (for ? placeholders)

        Returns:
            An APSW Cursor object (used to fetch results)

        Raises:
            NanaSQLiteConnectionError: If the connection is closed
            NanaSQLiteDatabaseError: If SQL execution fails

        Example:
            >>> cursor = db.execute("SELECT * FROM data WHERE key LIKE ?", ("user%",))
            >>> for row in cursor:
            ...     print(row)

            # When a cache refresh is needed:
            >>> db.execute("UPDATE data SET value = ? WHERE key = ?", ('"new"', "key"))
            >>> db.refresh("key")  # refresh the cache
        """
        self._check_connection()

        try:
            with self._lock:
                if parameters is None:
                    return self._connection.execute(sql)
                else:
                    return self._connection.execute(sql, parameters)
        except apsw.Error as e:
            raise NanaSQLiteDatabaseError(f"Failed to execute SQL: {e}", original_error=e) from e

    def execute_many(self, sql: str, parameters_list: list[tuple]) -> None:
        """
        Execute one SQL statement with many parameter sets

        Runs the same SQL statement once per parameter set, inside a transaction.
        Lets large numbers of INSERTs or UPDATEs run quickly.

        Args:
            sql: SQL statement to execute
            parameters_list: List of parameter tuples

        Example:
            >>> db.execute_many(
            ...     "INSERT OR REPLACE INTO custom (id, name) VALUES (?, ?)",
            ...     [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
            ... )
        """
        with self._lock:
            cursor = self._connection.cursor()
            cursor.execute("BEGIN IMMEDIATE")
            try:
                for parameters in parameters_list:
                    cursor.execute(sql, parameters)
                cursor.execute("COMMIT")
            except apsw.Error:
                cursor.execute("ROLLBACK")
                raise

    def fetch_one(self, sql: str, parameters: tuple = None) -> tuple | None:
        """
        Execute SQL and fetch a single row

        Args:
            sql: SQL statement to execute
            parameters: SQL parameters

        Returns:
            One result row (tuple), or None if there is no result

        Example:
            >>> row = db.fetch_one("SELECT value FROM data WHERE key = ?", ("user",))
            >>> print(row[0])
        """
        cursor = self.execute(sql, parameters)
        return cursor.fetchone()

    def fetch_all(self, sql: str, parameters: tuple = None) -> list[tuple]:
        """
        Execute SQL and fetch all rows

        Args:
            sql: SQL statement to execute
            parameters: SQL parameters

        Returns:
            All result rows (a list of tuples)

        Example:
            >>> rows = db.fetch_all("SELECT key, value FROM data WHERE key LIKE ?", ("user%",))
            >>> for key, value in rows:
            ...     print(key, value)
        """
        cursor = self.execute(sql, parameters)
        return cursor.fetchall()
# ==================== SQLite Wrapper Functions ====================
|
|
1224
|
+
|
|
1225
|
+
def create_table(self, table_name: str, columns: dict, if_not_exists: bool = True, primary_key: str = None) -> None:
|
|
1226
|
+
"""
|
|
1227
|
+
テーブルを作成
|
|
1228
|
+
|
|
1229
|
+
Args:
|
|
1230
|
+
table_name: テーブル名
|
|
1231
|
+
columns: カラム定義のdict(カラム名: SQL型)
|
|
1232
|
+
if_not_exists: Trueの場合、存在しない場合のみ作成
|
|
1233
|
+
primary_key: プライマリキーのカラム名(Noneの場合は指定なし)
|
|
1234
|
+
|
|
1235
|
+
Example:
|
|
1236
|
+
>>> db.create_table("users", {
|
|
1237
|
+
... "id": "INTEGER PRIMARY KEY",
|
|
1238
|
+
... "name": "TEXT NOT NULL",
|
|
1239
|
+
... "email": "TEXT UNIQUE",
|
|
1240
|
+
... "age": "INTEGER"
|
|
1241
|
+
... })
|
|
1242
|
+
>>> db.create_table("posts", {
|
|
1243
|
+
... "id": "INTEGER",
|
|
1244
|
+
... "title": "TEXT",
|
|
1245
|
+
... "content": "TEXT"
|
|
1246
|
+
... }, primary_key="id")
|
|
1247
|
+
"""
|
|
1248
|
+
if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
|
|
1249
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1250
|
+
|
|
1251
|
+
column_defs = []
|
|
1252
|
+
for col_name, col_type in columns.items():
|
|
1253
|
+
safe_col_name = self._sanitize_identifier(col_name)
|
|
1254
|
+
column_defs.append(f"{safe_col_name} {col_type}")
|
|
1255
|
+
|
|
1256
|
+
if primary_key:
|
|
1257
|
+
safe_pk = self._sanitize_identifier(primary_key)
|
|
1258
|
+
if not any(primary_key.upper() in col.upper() and "PRIMARY KEY" in col.upper() for col in column_defs):
|
|
1259
|
+
column_defs.append(f"PRIMARY KEY ({safe_pk})")
|
|
1260
|
+
|
|
1261
|
+
columns_sql = ", ".join(column_defs)
|
|
1262
|
+
sql = f"CREATE TABLE {if_not_exists_clause}{safe_table_name} ({columns_sql})"
|
|
1263
|
+
|
|
1264
|
+
self.execute(sql)
|
|
1265
|
+
|
|
1266
|
+
def create_index(
|
|
1267
|
+
self, index_name: str, table_name: str, columns: list[str], unique: bool = False, if_not_exists: bool = True
|
|
1268
|
+
) -> None:
|
|
1269
|
+
"""
|
|
1270
|
+
インデックスを作成
|
|
1271
|
+
|
|
1272
|
+
Args:
|
|
1273
|
+
index_name: インデックス名
|
|
1274
|
+
table_name: テーブル名
|
|
1275
|
+
columns: インデックスを作成するカラムのリスト
|
|
1276
|
+
unique: Trueの場合、ユニークインデックスを作成
|
|
1277
|
+
if_not_exists: Trueの場合、存在しない場合のみ作成
|
|
1278
|
+
|
|
1279
|
+
Example:
|
|
1280
|
+
>>> db.create_index("idx_users_email", "users", ["email"], unique=True)
|
|
1281
|
+
>>> db.create_index("idx_posts_user", "posts", ["user_id", "created_at"])
|
|
1282
|
+
"""
|
|
1283
|
+
unique_clause = "UNIQUE " if unique else ""
|
|
1284
|
+
if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
|
|
1285
|
+
safe_index_name = self._sanitize_identifier(index_name)
|
|
1286
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1287
|
+
safe_columns = [self._sanitize_identifier(col) for col in columns]
|
|
1288
|
+
columns_sql = ", ".join(safe_columns)
|
|
1289
|
+
|
|
1290
|
+
sql = (
|
|
1291
|
+
f"CREATE {unique_clause}INDEX {if_not_exists_clause}{safe_index_name} ON {safe_table_name} ({columns_sql})"
|
|
1292
|
+
)
|
|
1293
|
+
self.execute(sql)
|
|
1294
|
+
|
|
1295
|
+
def query(
|
|
1296
|
+
self,
|
|
1297
|
+
table_name: str = None,
|
|
1298
|
+
columns: list[str] = None,
|
|
1299
|
+
where: str = None,
|
|
1300
|
+
parameters: tuple = None,
|
|
1301
|
+
order_by: str = None,
|
|
1302
|
+
limit: int = None,
|
|
1303
|
+
strict_sql_validation: bool = None,
|
|
1304
|
+
allowed_sql_functions: list[str] = None,
|
|
1305
|
+
forbidden_sql_functions: list[str] = None,
|
|
1306
|
+
override_allowed: bool = False,
|
|
1307
|
+
) -> list[dict]:
|
|
1308
|
+
"""
|
|
1309
|
+
シンプルなSELECTクエリを実行
|
|
1310
|
+
|
|
1311
|
+
Args:
|
|
1312
|
+
table_name: テーブル名(Noneの場合はデフォルトテーブル)
|
|
1313
|
+
columns: 取得するカラムのリスト(Noneの場合は全カラム)
|
|
1314
|
+
where: WHERE句の条件(パラメータバインディング使用推奨)
|
|
1315
|
+
parameters: WHERE句のパラメータ
|
|
1316
|
+
order_by: ORDER BY句
|
|
1317
|
+
limit: LIMIT句
|
|
1318
|
+
strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
|
|
1319
|
+
allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
|
|
1320
|
+
forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
|
|
1321
|
+
override_allowed: Trueの場合、インスタンス許可設定を無視
|
|
1322
|
+
|
|
1323
|
+
Returns:
|
|
1324
|
+
結果のリスト(各行はdict)
|
|
1325
|
+
|
|
1326
|
+
Example:
|
|
1327
|
+
>>> # デフォルトテーブルから全データ取得
|
|
1328
|
+
>>> results = db.query()
|
|
1329
|
+
|
|
1330
|
+
>>> # 条件付き検索
|
|
1331
|
+
>>> results = db.query(
|
|
1332
|
+
... table_name="users",
|
|
1333
|
+
... columns=["id", "name", "email"],
|
|
1334
|
+
... where="age > ?",
|
|
1335
|
+
... parameters=(20,),
|
|
1336
|
+
... order_by="name ASC",
|
|
1337
|
+
... limit=10
|
|
1338
|
+
... )
|
|
1339
|
+
"""
|
|
1340
|
+
if table_name is None:
|
|
1341
|
+
table_name = self._table
|
|
1342
|
+
|
|
1343
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1344
|
+
|
|
1345
|
+
# バリデーション
|
|
1346
|
+
self._validate_expression(
|
|
1347
|
+
where,
|
|
1348
|
+
strict_sql_validation,
|
|
1349
|
+
allowed_sql_functions,
|
|
1350
|
+
forbidden_sql_functions,
|
|
1351
|
+
override_allowed,
|
|
1352
|
+
context="where",
|
|
1353
|
+
)
|
|
1354
|
+
self._validate_expression(
|
|
1355
|
+
order_by,
|
|
1356
|
+
strict_sql_validation,
|
|
1357
|
+
allowed_sql_functions,
|
|
1358
|
+
forbidden_sql_functions,
|
|
1359
|
+
override_allowed,
|
|
1360
|
+
context="order_by",
|
|
1361
|
+
)
|
|
1362
|
+
if columns:
|
|
1363
|
+
for col in columns:
|
|
1364
|
+
# 関数使用の可能性を考慮して識別子サニタイズは行わないがバリデーションは行う
|
|
1365
|
+
self._validate_expression(
|
|
1366
|
+
col,
|
|
1367
|
+
strict_sql_validation,
|
|
1368
|
+
allowed_sql_functions,
|
|
1369
|
+
forbidden_sql_functions,
|
|
1370
|
+
override_allowed,
|
|
1371
|
+
context="column",
|
|
1372
|
+
)
|
|
1373
|
+
|
|
1374
|
+
# カラム指定
|
|
1375
|
+
if columns is None:
|
|
1376
|
+
columns_sql = "*"
|
|
1377
|
+
# カラム名は後でPRAGMAから取得
|
|
1378
|
+
else:
|
|
1379
|
+
# 識別子(カラム名のみ)の場合はサニタイズ、式の場合はそのまま(バリデーション済み)
|
|
1380
|
+
safe_cols = []
|
|
1381
|
+
for col in columns:
|
|
1382
|
+
if IDENTIFIER_PATTERN.match(col):
|
|
1383
|
+
safe_cols.append(self._sanitize_identifier(col))
|
|
1384
|
+
else:
|
|
1385
|
+
safe_cols.append(col)
|
|
1386
|
+
columns_sql = ", ".join(safe_cols)
|
|
1387
|
+
|
|
1388
|
+
# Validate limit is an integer and non-negative if provided
|
|
1389
|
+
if limit is not None:
|
|
1390
|
+
if not isinstance(limit, int):
|
|
1391
|
+
raise ValueError(f"limit must be an integer, got {type(limit).__name__}")
|
|
1392
|
+
if limit < 0:
|
|
1393
|
+
raise ValueError("limit must be non-negative")
|
|
1394
|
+
|
|
1395
|
+
# SQL構築
|
|
1396
|
+
sql = f"SELECT {columns_sql} FROM {safe_table_name}" # nosec
|
|
1397
|
+
|
|
1398
|
+
if where:
|
|
1399
|
+
sql += f" WHERE {where}"
|
|
1400
|
+
|
|
1401
|
+
if order_by:
|
|
1402
|
+
sql += f" ORDER BY {order_by}"
|
|
1403
|
+
|
|
1404
|
+
if limit is not None:
|
|
1405
|
+
sql += f" LIMIT {limit}"
|
|
1406
|
+
|
|
1407
|
+
# 実行
|
|
1408
|
+
cursor = self.execute(sql, parameters)
|
|
1409
|
+
|
|
1410
|
+
# カラム名取得
|
|
1411
|
+
if columns is None:
|
|
1412
|
+
# 全カラムの場合、テーブル情報から取得
|
|
1413
|
+
pragma_cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
|
|
1414
|
+
col_names = [row[1] for row in pragma_cursor]
|
|
1415
|
+
else:
|
|
1416
|
+
# Extract aliases from AS clauses, similar to query_with_pagination
|
|
1417
|
+
col_names = []
|
|
1418
|
+
for col in columns:
|
|
1419
|
+
parts = re.split(r"\s+as\s+", col, flags=re.IGNORECASE)
|
|
1420
|
+
if len(parts) > 1:
|
|
1421
|
+
# Use the alias (after AS)
|
|
1422
|
+
col_names.append(parts[-1].strip().strip('"').strip("'"))
|
|
1423
|
+
else:
|
|
1424
|
+
# Use the column expression as-is
|
|
1425
|
+
col_names.append(col.strip())
|
|
1426
|
+
|
|
1427
|
+
# 結果をdictのリストに変換
|
|
1428
|
+
results = []
|
|
1429
|
+
for row in cursor:
|
|
1430
|
+
results.append(dict(zip(col_names, row)))
|
|
1431
|
+
|
|
1432
|
+
return results
|
|
1433
|
+
|
|
1434
|
+
def table_exists(self, table_name: str) -> bool:
|
|
1435
|
+
"""
|
|
1436
|
+
テーブルの存在確認
|
|
1437
|
+
|
|
1438
|
+
Args:
|
|
1439
|
+
table_name: テーブル名
|
|
1440
|
+
|
|
1441
|
+
Returns:
|
|
1442
|
+
存在する場合True、しない場合False
|
|
1443
|
+
|
|
1444
|
+
Example:
|
|
1445
|
+
>>> if db.table_exists("users"):
|
|
1446
|
+
... print("users table exists")
|
|
1447
|
+
"""
|
|
1448
|
+
cursor = self.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
|
|
1449
|
+
return cursor.fetchone() is not None
|
|
1450
|
+
|
|
1451
|
+
def list_tables(self) -> list[str]:
|
|
1452
|
+
"""
|
|
1453
|
+
データベース内の全テーブル一覧を取得
|
|
1454
|
+
|
|
1455
|
+
Returns:
|
|
1456
|
+
テーブル名のリスト
|
|
1457
|
+
|
|
1458
|
+
Example:
|
|
1459
|
+
>>> tables = db.list_tables()
|
|
1460
|
+
>>> print(tables) # ['data', 'users', 'posts']
|
|
1461
|
+
"""
|
|
1462
|
+
cursor = self.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
|
1463
|
+
return [row[0] for row in cursor]
|
|
1464
|
+
|
|
1465
|
+
def drop_table(self, table_name: str, if_exists: bool = True) -> None:
|
|
1466
|
+
"""
|
|
1467
|
+
テーブルを削除
|
|
1468
|
+
|
|
1469
|
+
Args:
|
|
1470
|
+
table_name: テーブル名
|
|
1471
|
+
if_exists: Trueの場合、存在する場合のみ削除(エラーを防ぐ)
|
|
1472
|
+
|
|
1473
|
+
Example:
|
|
1474
|
+
>>> db.drop_table("old_table")
|
|
1475
|
+
>>> db.drop_table("temp", if_exists=True)
|
|
1476
|
+
"""
|
|
1477
|
+
if_exists_clause = "IF EXISTS " if if_exists else ""
|
|
1478
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1479
|
+
sql = f"DROP TABLE {if_exists_clause}{safe_table_name}"
|
|
1480
|
+
self.execute(sql)
|
|
1481
|
+
|
|
1482
|
+
def drop_index(self, index_name: str, if_exists: bool = True) -> None:
|
|
1483
|
+
"""
|
|
1484
|
+
インデックスを削除
|
|
1485
|
+
|
|
1486
|
+
Args:
|
|
1487
|
+
index_name: インデックス名
|
|
1488
|
+
if_exists: Trueの場合、存在する場合のみ削除
|
|
1489
|
+
|
|
1490
|
+
Example:
|
|
1491
|
+
>>> db.drop_index("idx_users_email")
|
|
1492
|
+
"""
|
|
1493
|
+
if_exists_clause = "IF EXISTS " if if_exists else ""
|
|
1494
|
+
safe_index_name = self._sanitize_identifier(index_name)
|
|
1495
|
+
sql = f"DROP INDEX {if_exists_clause}{safe_index_name}"
|
|
1496
|
+
self.execute(sql)
|
|
1497
|
+
|
|
1498
|
+
def alter_table_add_column(self, table_name: str, column_name: str, column_type: str, default: Any = None) -> None:
|
|
1499
|
+
"""
|
|
1500
|
+
既存テーブルにカラムを追加
|
|
1501
|
+
|
|
1502
|
+
Args:
|
|
1503
|
+
table_name: テーブル名
|
|
1504
|
+
column_name: カラム名
|
|
1505
|
+
column_type: カラムの型(SQL型)
|
|
1506
|
+
default: デフォルト値(Noneの場合は指定なし)
|
|
1507
|
+
|
|
1508
|
+
Example:
|
|
1509
|
+
>>> db.alter_table_add_column("users", "phone", "TEXT")
|
|
1510
|
+
>>> db.alter_table_add_column("users", "status", "TEXT", default="'active'")
|
|
1511
|
+
"""
|
|
1512
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1513
|
+
safe_column_name = self._sanitize_identifier(column_name)
|
|
1514
|
+
# column_type is a SQL type string - validate it doesn't contain dangerous characters
|
|
1515
|
+
# Also check for closing parenthesis which could break out of ALTER TABLE structure
|
|
1516
|
+
if any(c in column_type for c in [";", "'", ")"]) or "--" in column_type or "/*" in column_type:
|
|
1517
|
+
raise ValueError(f"Invalid or dangerous column type: {column_type}")
|
|
1518
|
+
|
|
1519
|
+
sql = f"ALTER TABLE {safe_table_name} ADD COLUMN {safe_column_name} {column_type}"
|
|
1520
|
+
if default is not None:
|
|
1521
|
+
# For default values: if it's a string, ensure it's properly quoted and escaped
|
|
1522
|
+
if isinstance(default, str):
|
|
1523
|
+
# Strip leading/trailing single quotes if present, then escape and re-quote
|
|
1524
|
+
stripped = default
|
|
1525
|
+
if stripped.startswith("'") and stripped.endswith("'") and len(stripped) >= 2:
|
|
1526
|
+
stripped = stripped[1:-1]
|
|
1527
|
+
# Escape single quotes for SQL string literal (double them: ' becomes '')
|
|
1528
|
+
escaped_default = stripped.replace("'", "''")
|
|
1529
|
+
default = f"'{escaped_default}'"
|
|
1530
|
+
sql += f" DEFAULT {default}"
|
|
1531
|
+
self.execute(sql)
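# --- Editorial usage sketch (not part of the released source) ---
# Illustrates the default-value handling above: a plain string and a pre-quoted
# string both end up as DEFAULT 'active', because the method strips surrounding
# quotes, escapes embedded quotes, and re-quotes. Table/column names are examples.
#
#     db.alter_table_add_column("users", "status", "TEXT", default="active")
#     db.alter_table_add_column("users", "status2", "TEXT", default="'active'")
#     db.alter_table_add_column("users", "login_count", "INTEGER", default=0)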
|
|
1532
|
+
|
|
1533
|
+
def get_table_schema(self, table_name: str) -> list[dict]:
|
|
1534
|
+
"""
|
|
1535
|
+
テーブル構造を取得
|
|
1536
|
+
|
|
1537
|
+
Args:
|
|
1538
|
+
table_name: テーブル名
|
|
1539
|
+
|
|
1540
|
+
Returns:
|
|
1541
|
+
カラム情報のリスト(各カラムはdict)
|
|
1542
|
+
|
|
1543
|
+
Example:
|
|
1544
|
+
>>> schema = db.get_table_schema("users")
|
|
1545
|
+
>>> for col in schema:
|
|
1546
|
+
... print(f"{col['name']}: {col['type']}")
|
|
1547
|
+
"""
|
|
1548
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1549
|
+
cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
|
|
1550
|
+
columns = []
|
|
1551
|
+
for row in cursor:
|
|
1552
|
+
columns.append(
|
|
1553
|
+
{
|
|
1554
|
+
"cid": row[0],
|
|
1555
|
+
"name": row[1],
|
|
1556
|
+
"type": row[2],
|
|
1557
|
+
"notnull": bool(row[3]),
|
|
1558
|
+
"default_value": row[4],
|
|
1559
|
+
"pk": bool(row[5]),
|
|
1560
|
+
}
|
|
1561
|
+
)
|
|
1562
|
+
return columns
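# --- Editorial usage sketch (not part of the released source) ---
# A minimal "add column if missing" migration built on get_table_schema() and
# alter_table_add_column(); the `users`/`phone` names are hypothetical.
#
#     existing = {col["name"] for col in db.get_table_schema("users")}
#     if "phone" not in existing:
#         db.alter_table_add_column("users", "phone", "TEXT")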
|
|
1563
|
+
|
|
1564
|
+
def list_indexes(self, table_name: str = None) -> list[dict]:
|
|
1565
|
+
"""
|
|
1566
|
+
インデックス一覧を取得
|
|
1567
|
+
|
|
1568
|
+
Args:
|
|
1569
|
+
table_name: テーブル名(Noneの場合は全インデックス)
|
|
1570
|
+
|
|
1571
|
+
Returns:
|
|
1572
|
+
インデックス情報のリスト
|
|
1573
|
+
|
|
1574
|
+
Example:
|
|
1575
|
+
>>> indexes = db.list_indexes("users")
|
|
1576
|
+
>>> for idx in indexes:
|
|
1577
|
+
... print(f"{idx['name']}: {idx['columns']}")
|
|
1578
|
+
"""
|
|
1579
|
+
if table_name:
|
|
1580
|
+
cursor = self.execute(
|
|
1581
|
+
"SELECT name, tbl_name, sql FROM sqlite_master WHERE type='index' AND tbl_name=? ORDER BY name",
|
|
1582
|
+
(table_name,),
|
|
1583
|
+
)
|
|
1584
|
+
else:
|
|
1585
|
+
cursor = self.execute("SELECT name, tbl_name, sql FROM sqlite_master WHERE type='index' ORDER BY name")
|
|
1586
|
+
|
|
1587
|
+
indexes = []
|
|
1588
|
+
for row in cursor:
|
|
1589
|
+
if row[0] and not row[0].startswith("sqlite_"): # Skip auto-created indexes
|
|
1590
|
+
indexes.append({"name": row[0], "table": row[1], "sql": row[2]})
|
|
1591
|
+
return indexes
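# --- Editorial usage sketch (not part of the released source) ---
# Drops every user-created index on one table; sqlite_* auto indexes are already
# filtered out by list_indexes(). Assumes an open instance `db`.
#
#     for idx in db.list_indexes("users"):
#         db.drop_index(idx["name"])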
|
|
1592
|
+
|
|
1593
|
+
# ==================== Data Operation Wrappers ====================
|
|
1594
|
+
|
|
1595
|
+
def sql_insert(self, table_name: str, data: dict) -> int:
|
|
1596
|
+
"""
|
|
1597
|
+
dictから直接INSERT
|
|
1598
|
+
|
|
1599
|
+
Args:
|
|
1600
|
+
table_name: テーブル名
|
|
1601
|
+
data: カラム名と値のdict
|
|
1602
|
+
|
|
1603
|
+
Returns:
|
|
1604
|
+
挿入されたROWID
|
|
1605
|
+
|
|
1606
|
+
Example:
|
|
1607
|
+
>>> rowid = db.sql_insert("users", {
|
|
1608
|
+
... "name": "Alice",
|
|
1609
|
+
... "email": "alice@example.com",
|
|
1610
|
+
... "age": 25
|
|
1611
|
+
... })
|
|
1612
|
+
"""
|
|
1613
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1614
|
+
safe_columns = [self._sanitize_identifier(col) for col in data.keys()]
|
|
1615
|
+
values = list(data.values())
|
|
1616
|
+
placeholders = ", ".join(["?"] * len(values))
|
|
1617
|
+
columns_sql = ", ".join(safe_columns)
|
|
1618
|
+
|
|
1619
|
+
sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
|
|
1620
|
+
self.execute(sql, tuple(values))
|
|
1621
|
+
|
|
1622
|
+
return self.get_last_insert_rowid()
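# --- Editorial usage sketch (not part of the released source) ---
# sql_insert() already returns the new ROWID, so a separate
# get_last_insert_rowid() call is optional. The `users` columns are illustrative.
#
#     rowid = db.sql_insert("users", {"name": "Alice", "email": "alice@example.com"})
#     print("inserted row", rowid)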
|
|
1623
|
+
|
|
1624
|
+
def sql_update(self, table_name: str, data: dict, where: str, parameters: tuple = None) -> int:
|
|
1625
|
+
"""
|
|
1626
|
+
dictとwhere条件でUPDATE
|
|
1627
|
+
|
|
1628
|
+
Args:
|
|
1629
|
+
table_name: テーブル名
|
|
1630
|
+
data: 更新するカラム名と値のdict
|
|
1631
|
+
where: WHERE句の条件
|
|
1632
|
+
parameters: WHERE句のパラメータ
|
|
1633
|
+
|
|
1634
|
+
Returns:
|
|
1635
|
+
更新された行数
|
|
1636
|
+
|
|
1637
|
+
Example:
|
|
1638
|
+
>>> count = db.sql_update("users",
|
|
1639
|
+
... {"age": 26, "status": "active"},
|
|
1640
|
+
... "name = ?",
|
|
1641
|
+
... ("Alice",)
|
|
1642
|
+
... )
|
|
1643
|
+
"""
|
|
1644
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1645
|
+
safe_set_items = [f"{self._sanitize_identifier(col)} = ?" for col in data.keys()]
|
|
1646
|
+
set_clause = ", ".join(safe_set_items)
|
|
1647
|
+
values = list(data.values())
|
|
1648
|
+
|
|
1649
|
+
sql = f"UPDATE {safe_table_name} SET {set_clause} WHERE {where}" # nosec
|
|
1650
|
+
|
|
1651
|
+
if parameters:
|
|
1652
|
+
values.extend(parameters)
|
|
1653
|
+
|
|
1654
|
+
self.execute(sql, tuple(values))
|
|
1655
|
+
return self._connection.changes()
|
|
1656
|
+
|
|
1657
|
+
def sql_delete(self, table_name: str, where: str, parameters: tuple = None) -> int:
|
|
1658
|
+
"""
|
|
1659
|
+
where条件でDELETE
|
|
1660
|
+
|
|
1661
|
+
Args:
|
|
1662
|
+
table_name: テーブル名
|
|
1663
|
+
where: WHERE句の条件
|
|
1664
|
+
parameters: WHERE句のパラメータ
|
|
1665
|
+
|
|
1666
|
+
Returns:
|
|
1667
|
+
削除された行数
|
|
1668
|
+
|
|
1669
|
+
Example:
|
|
1670
|
+
>>> count = db.sql_delete("users", "age < ?", (18,))
|
|
1671
|
+
"""
|
|
1672
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1673
|
+
sql = f"DELETE FROM {safe_table_name} WHERE {where}" # nosec
|
|
1674
|
+
self.execute(sql, parameters)
|
|
1675
|
+
return self._connection.changes()
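# --- Editorial usage sketch (not part of the released source) ---
# sql_update() and sql_delete() both return the number of affected rows
# (connection.changes()), which makes simple audit logging easy. Names are
# illustrative; `db` is an open NanaSQLite instance.
#
#     updated = db.sql_update("users", {"status": "inactive"},
#                             "last_login < ?", ("2023-01-01",))
#     removed = db.sql_delete("users", "status = ? AND last_login IS NULL", ("inactive",))
#     print(f"updated={updated} removed={removed}")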
|
|
1676
|
+
|
|
1677
|
+
def upsert(self, table_name: str, data: dict, conflict_columns: list[str] = None) -> int:
|
|
1678
|
+
"""
|
|
1679
|
+
INSERT OR REPLACE の簡易版(upsert)
|
|
1680
|
+
|
|
1681
|
+
Args:
|
|
1682
|
+
table_name: テーブル名
|
|
1683
|
+
data: カラム名と値のdict
|
|
1684
|
+
conflict_columns: 競合判定に使用するカラム(Noneの場合はINSERT OR REPLACE)
|
|
1685
|
+
|
|
1686
|
+
Returns:
|
|
1687
|
+
挿入/更新されたROWID(ON CONFLICT DO NOTHINGで行が挿入されなかった場合は0)
|
|
1688
|
+
|
|
1689
|
+
Example:
|
|
1690
|
+
>>> # 単純なINSERT OR REPLACE
|
|
1691
|
+
>>> db.upsert("users", {"id": 1, "name": "Alice", "age": 25})
|
|
1692
|
+
|
|
1693
|
+
>>> # ON CONFLICT句を使用
|
|
1694
|
+
>>> db.upsert("users",
|
|
1695
|
+
... {"email": "alice@example.com", "name": "Alice", "age": 26},
|
|
1696
|
+
... conflict_columns=["email"]
|
|
1697
|
+
... )
|
|
1698
|
+
"""
|
|
1699
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1700
|
+
safe_columns = [self._sanitize_identifier(col) for col in data.keys()]
|
|
1701
|
+
values = list(data.values())
|
|
1702
|
+
placeholders = ", ".join(["?"] * len(values))
|
|
1703
|
+
columns_sql = ", ".join(safe_columns)
|
|
1704
|
+
|
|
1705
|
+
if conflict_columns:
|
|
1706
|
+
# ON CONFLICT を使用
|
|
1707
|
+
safe_conflict_cols = [self._sanitize_identifier(col) for col in conflict_columns]
|
|
1708
|
+
conflict_cols_sql = ", ".join(safe_conflict_cols)
|
|
1709
|
+
|
|
1710
|
+
update_items = [
|
|
1711
|
+
f"{self._sanitize_identifier(col)} = excluded.{self._sanitize_identifier(col)}"
|
|
1712
|
+
for col in data.keys()
|
|
1713
|
+
if col not in conflict_columns
|
|
1714
|
+
]
|
|
1715
|
+
|
|
1716
|
+
if update_items:
|
|
1717
|
+
update_clause = ", ".join(update_items)
|
|
1718
|
+
else:
|
|
1719
|
+
# 全カラムが競合カラムの場合は、何もしない(既存データを保持)
|
|
1720
|
+
sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders}) " # nosec
|
|
1721
|
+
sql += f"ON CONFLICT({conflict_cols_sql}) DO NOTHING" # nosec
|
|
1722
|
+
self.execute(sql, tuple(values))
|
|
1723
|
+
# When DO NOTHING is triggered, no row is inserted, so return 0
|
|
1724
|
+
# Check only the most recent operation's change count
|
|
1725
|
+
if self._connection.changes() == 0:
|
|
1726
|
+
return 0
|
|
1727
|
+
return self.get_last_insert_rowid()
|
|
1728
|
+
|
|
1729
|
+
sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders}) " # nosec
|
|
1730
|
+
sql += f"ON CONFLICT({conflict_cols_sql}) DO UPDATE SET {update_clause}" # nosec
|
|
1731
|
+
else:
|
|
1732
|
+
# INSERT OR REPLACE
|
|
1733
|
+
sql = f"INSERT OR REPLACE INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
|
|
1734
|
+
|
|
1735
|
+
self.execute(sql, tuple(values))
|
|
1736
|
+
return self.get_last_insert_rowid()
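# --- Editorial usage sketch (not part of the released source) ---
# With conflict_columns, upsert() emits ON CONFLICT ... DO UPDATE; when every
# supplied column is a conflict column it falls back to DO NOTHING and returns 0
# instead of a ROWID. Assumes a `users` table with a unique `email` column.
#
#     rowid = db.upsert("users", {"email": "a@example.com", "name": "Alice"},
#                       conflict_columns=["email"])
#     skipped = db.upsert("users", {"email": "a@example.com"},
#                         conflict_columns=["email"])   # DO NOTHING -> 0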
|
|
1737
|
+
|
|
1738
|
+
def count(
|
|
1739
|
+
self,
|
|
1740
|
+
table_name: str = None,
|
|
1741
|
+
where: str = None,
|
|
1742
|
+
parameters: tuple = None,
|
|
1743
|
+
strict_sql_validation: bool = None,
|
|
1744
|
+
allowed_sql_functions: list[str] = None,
|
|
1745
|
+
forbidden_sql_functions: list[str] = None,
|
|
1746
|
+
override_allowed: bool = False,
|
|
1747
|
+
) -> int:
|
|
1748
|
+
"""
|
|
1749
|
+
レコード数を取得
|
|
1750
|
+
|
|
1751
|
+
Args:
|
|
1752
|
+
table_name: テーブル名(Noneの場合はデフォルトテーブル)
|
|
1753
|
+
where: WHERE句の条件(オプション)
|
|
1754
|
+
parameters: WHERE句のパラメータ
|
|
1755
|
+
strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
|
|
1756
|
+
allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
|
|
1757
|
+
forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
|
|
1758
|
+
override_allowed: Trueの場合、インスタンス許可設定を無視
|
|
1759
|
+
|
|
1760
|
+
Example:
|
|
1761
|
+
>>> total = db.count("users")
|
|
1762
|
+
>>> adults = db.count("users", "age >= ?", (18,))
|
|
1763
|
+
"""
|
|
1764
|
+
if table_name is None:
|
|
1765
|
+
table_name = self._table
|
|
1766
|
+
|
|
1767
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1768
|
+
|
|
1769
|
+
# バリデーション
|
|
1770
|
+
self._validate_expression(
|
|
1771
|
+
where, strict_sql_validation, allowed_sql_functions, forbidden_sql_functions, override_allowed
|
|
1772
|
+
)
|
|
1773
|
+
|
|
1774
|
+
sql = f"SELECT COUNT(*) FROM {safe_table_name}" # nosec
|
|
1775
|
+
if where:
|
|
1776
|
+
sql += f" WHERE {where}"
|
|
1777
|
+
|
|
1778
|
+
cursor = self.execute(sql, parameters)
|
|
1779
|
+
return cursor.fetchone()[0]
|
|
1780
|
+
|
|
1781
|
+
def exists(self, table_name: str, where: str, parameters: tuple = None) -> bool:
|
|
1782
|
+
"""
|
|
1783
|
+
レコードの存在確認
|
|
1784
|
+
|
|
1785
|
+
Args:
|
|
1786
|
+
table_name: テーブル名
|
|
1787
|
+
where: WHERE句の条件
|
|
1788
|
+
parameters: WHERE句のパラメータ
|
|
1789
|
+
|
|
1790
|
+
Returns:
|
|
1791
|
+
存在する場合True
|
|
1792
|
+
|
|
1793
|
+
Example:
|
|
1794
|
+
>>> if db.exists("users", "email = ?", ("alice@example.com",)):
|
|
1795
|
+
... print("User exists")
|
|
1796
|
+
"""
|
|
1797
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1798
|
+
sql = f"SELECT EXISTS(SELECT 1 FROM {safe_table_name} WHERE {where})" # nosec
|
|
1799
|
+
cursor = self.execute(sql, parameters)
|
|
1800
|
+
return bool(cursor.fetchone()[0])
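# --- Editorial usage sketch (not part of the released source) ---
# count() and exists() cover the common "how many / is there any" checks without
# hand-written SELECTs. Table and column names are illustrative.
#
#     if not db.exists("users", "email = ?", ("alice@example.com",)):
#         db.sql_insert("users", {"email": "alice@example.com", "name": "Alice"})
#     print("adults:", db.count("users", "age >= ?", (18,)))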
|
|
1801
|
+
|
|
1802
|
+
# ==================== Query Extensions ====================
|
|
1803
|
+
|
|
1804
|
+
def query_with_pagination(
|
|
1805
|
+
self,
|
|
1806
|
+
table_name: str = None,
|
|
1807
|
+
columns: list[str] = None,
|
|
1808
|
+
where: str = None,
|
|
1809
|
+
parameters: tuple = None,
|
|
1810
|
+
order_by: str = None,
|
|
1811
|
+
limit: int = None,
|
|
1812
|
+
offset: int = None,
|
|
1813
|
+
group_by: str = None,
|
|
1814
|
+
strict_sql_validation: bool = None,
|
|
1815
|
+
allowed_sql_functions: list[str] = None,
|
|
1816
|
+
forbidden_sql_functions: list[str] = None,
|
|
1817
|
+
override_allowed: bool = False,
|
|
1818
|
+
) -> list[dict]:
|
|
1819
|
+
"""
|
|
1820
|
+
拡張されたクエリ(offset、group_by対応)
|
|
1821
|
+
|
|
1822
|
+
Args:
|
|
1823
|
+
table_name: テーブル名
|
|
1824
|
+
columns: 取得するカラム
|
|
1825
|
+
where: WHERE句
|
|
1826
|
+
parameters: パラメータ
|
|
1827
|
+
order_by: ORDER BY句
|
|
1828
|
+
limit: LIMIT句
|
|
1829
|
+
offset: OFFSET句(ページネーション用)
|
|
1830
|
+
group_by: GROUP BY句
|
|
1831
|
+
strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
|
|
1832
|
+
allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
|
|
1833
|
+
forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
|
|
1834
|
+
override_allowed: Trueの場合、インスタンス許可設定を無視
|
|
1835
|
+
|
|
1836
|
+
Returns:
|
|
1837
|
+
結果のリスト
|
|
1838
|
+
|
|
1839
|
+
Example:
|
|
1840
|
+
>>> # ページネーション
|
|
1841
|
+
>>> page2 = db.query_with_pagination("users",
|
|
1842
|
+
... limit=10, offset=10, order_by="id ASC")
|
|
1843
|
+
|
|
1844
|
+
>>> # グループ集計
|
|
1845
|
+
>>> stats = db.query_with_pagination("orders",
|
|
1846
|
+
... columns=["user_id", "COUNT(*) as order_count"],
|
|
1847
|
+
... group_by="user_id"
|
|
1848
|
+
... )
|
|
1849
|
+
"""
|
|
1850
|
+
if table_name is None:
|
|
1851
|
+
table_name = self._table
|
|
1852
|
+
|
|
1853
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
1854
|
+
|
|
1855
|
+
# バリデーション
|
|
1856
|
+
self._validate_expression(
|
|
1857
|
+
where,
|
|
1858
|
+
strict_sql_validation,
|
|
1859
|
+
allowed_sql_functions,
|
|
1860
|
+
forbidden_sql_functions,
|
|
1861
|
+
override_allowed,
|
|
1862
|
+
context="where",
|
|
1863
|
+
)
|
|
1864
|
+
self._validate_expression(
|
|
1865
|
+
order_by,
|
|
1866
|
+
strict_sql_validation,
|
|
1867
|
+
allowed_sql_functions,
|
|
1868
|
+
forbidden_sql_functions,
|
|
1869
|
+
override_allowed,
|
|
1870
|
+
context="order_by",
|
|
1871
|
+
)
|
|
1872
|
+
self._validate_expression(
|
|
1873
|
+
group_by,
|
|
1874
|
+
strict_sql_validation,
|
|
1875
|
+
allowed_sql_functions,
|
|
1876
|
+
forbidden_sql_functions,
|
|
1877
|
+
override_allowed,
|
|
1878
|
+
context="group_by",
|
|
1879
|
+
)
|
|
1880
|
+
if columns:
|
|
1881
|
+
for col in columns:
|
|
1882
|
+
self._validate_expression(
|
|
1883
|
+
col,
|
|
1884
|
+
strict_sql_validation,
|
|
1885
|
+
allowed_sql_functions,
|
|
1886
|
+
forbidden_sql_functions,
|
|
1887
|
+
override_allowed,
|
|
1888
|
+
context="column",
|
|
1889
|
+
)
|
|
1890
|
+
|
|
1891
|
+
# Validate limit and offset are non-negative integers if provided
|
|
1892
|
+
if limit is not None:
|
|
1893
|
+
if not isinstance(limit, int):
|
|
1894
|
+
raise ValueError(f"limit must be an integer, got {type(limit).__name__}")
|
|
1895
|
+
if limit < 0:
|
|
1896
|
+
raise ValueError("limit must be non-negative")
|
|
1897
|
+
|
|
1898
|
+
if offset is not None:
|
|
1899
|
+
if not isinstance(offset, int):
|
|
1900
|
+
raise ValueError(f"offset must be an integer, got {type(offset).__name__}")
|
|
1901
|
+
if offset < 0:
|
|
1902
|
+
raise ValueError("offset must be non-negative")
|
|
1903
|
+
|
|
1904
|
+
# カラム指定
|
|
1905
|
+
if columns is None:
|
|
1906
|
+
columns_sql = "*"
|
|
1907
|
+
else:
|
|
1908
|
+
# 識別子(カラム名のみ)の場合はサニタイズ、式の場合はそのまま(バリデーション済み)
|
|
1909
|
+
safe_cols = []
|
|
1910
|
+
for col in columns:
|
|
1911
|
+
if IDENTIFIER_PATTERN.match(col):
|
|
1912
|
+
safe_cols.append(self._sanitize_identifier(col))
|
|
1913
|
+
else:
|
|
1914
|
+
safe_cols.append(col)
|
|
1915
|
+
columns_sql = ", ".join(safe_cols)
|
|
1916
|
+
|
|
1917
|
+
# SQL構築
|
|
1918
|
+
sql = f"SELECT {columns_sql} FROM {safe_table_name}" # nosec
|
|
1919
|
+
|
|
1920
|
+
if where:
|
|
1921
|
+
sql += f" WHERE {where}"
|
|
1922
|
+
|
|
1923
|
+
if group_by:
|
|
1924
|
+
sql += f" GROUP BY {group_by}"
|
|
1925
|
+
|
|
1926
|
+
if order_by:
|
|
1927
|
+
sql += f" ORDER BY {order_by}"
|
|
1928
|
+
|
|
1929
|
+
if limit is not None:
|
|
1930
|
+
sql += f" LIMIT {limit}"
|
|
1931
|
+
|
|
1932
|
+
if offset is not None:
|
|
1933
|
+
sql += f" OFFSET {offset}"
|
|
1934
|
+
|
|
1935
|
+
# 実行
|
|
1936
|
+
cursor = self.execute(sql, parameters)
|
|
1937
|
+
|
|
1938
|
+
# カラム名取得
|
|
1939
|
+
if columns is None:
|
|
1940
|
+
pragma_cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
|
|
1941
|
+
col_names = [row[1] for row in pragma_cursor]
|
|
1942
|
+
else:
|
|
1943
|
+
# カラム名からAS句を考慮(case-insensitive)
|
|
1944
|
+
col_names = []
|
|
1945
|
+
for col in columns:
|
|
1946
|
+
parts = re.split(r"\s+as\s+", col, flags=re.IGNORECASE)
|
|
1947
|
+
if len(parts) > 1:
|
|
1948
|
+
col_names.append(parts[-1].strip().strip('"').strip("'"))
|
|
1949
|
+
else:
|
|
1950
|
+
col_names.append(col.strip().strip('"').strip("'"))
|
|
1951
|
+
|
|
1952
|
+
# 結果をdictのリストに変換
|
|
1953
|
+
results = []
|
|
1954
|
+
for row in cursor:
|
|
1955
|
+
results.append(dict(zip(col_names, row)))
|
|
1956
|
+
|
|
1957
|
+
return results
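# --- Editorial usage sketch (not part of the released source) ---
# One way to walk a large table page by page with query_with_pagination();
# the page size, ORDER BY column, and `process` callback are placeholders.
#
#     page_size, offset = 100, 0
#     while True:
#         rows = db.query_with_pagination("users", order_by="id ASC",
#                                         limit=page_size, offset=offset)
#         if not rows:
#             break
#         process(rows)
#         offset += page_size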
|
|
1958
|
+
|
|
1959
|
+
# ==================== Utility Functions ====================
|
|
1960
|
+
|
|
1961
|
+
def vacuum(self) -> None:
|
|
1962
|
+
"""
|
|
1963
|
+
データベースを最適化(VACUUM実行)
|
|
1964
|
+
|
|
1965
|
+
削除されたレコードの領域を回収し、データベースファイルを最適化。
|
|
1966
|
+
|
|
1967
|
+
Example:
|
|
1968
|
+
>>> db.vacuum()
|
|
1969
|
+
"""
|
|
1970
|
+
self.execute("VACUUM")
|
|
1971
|
+
|
|
1972
|
+
def get_db_size(self) -> int:
|
|
1973
|
+
"""
|
|
1974
|
+
データベースファイルのサイズを取得(バイト単位)
|
|
1975
|
+
|
|
1976
|
+
Returns:
|
|
1977
|
+
データベースファイルのサイズ
|
|
1978
|
+
|
|
1979
|
+
Example:
|
|
1980
|
+
>>> size = db.get_db_size()
|
|
1981
|
+
>>> print(f"DB size: {size / 1024 / 1024:.2f} MB")
|
|
1982
|
+
"""
|
|
1983
|
+
import os
|
|
1984
|
+
|
|
1985
|
+
return os.path.getsize(self._db_path)
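# --- Editorial usage sketch (not part of the released source) ---
# Simple maintenance check: report the file size before and after VACUUM.
#
#     before = db.get_db_size()
#     db.vacuum()
#     print(f"reclaimed {(before - db.get_db_size()) / 1024:.1f} KiB")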
|
|
1986
|
+
|
|
1987
|
+
def export_table_to_dict(self, table_name: str) -> list[dict]:
|
|
1988
|
+
"""
|
|
1989
|
+
テーブル全体をdictのリストとして取得
|
|
1990
|
+
|
|
1991
|
+
Args:
|
|
1992
|
+
table_name: テーブル名
|
|
1993
|
+
|
|
1994
|
+
Returns:
|
|
1995
|
+
全レコードのリスト
|
|
1996
|
+
|
|
1997
|
+
Example:
|
|
1998
|
+
>>> all_users = db.export_table_to_dict("users")
|
|
1999
|
+
"""
|
|
2000
|
+
return self.query_with_pagination(table_name=table_name)
|
|
2001
|
+
|
|
2002
|
+
def import_from_dict_list(self, table_name: str, data_list: list[dict]) -> int:
|
|
2003
|
+
"""
|
|
2004
|
+
dictのリストからテーブルに一括挿入
|
|
2005
|
+
|
|
2006
|
+
Args:
|
|
2007
|
+
table_name: テーブル名
|
|
2008
|
+
data_list: 挿入するデータのリスト
|
|
2009
|
+
|
|
2010
|
+
Returns:
|
|
2011
|
+
挿入された行数
|
|
2012
|
+
|
|
2013
|
+
Example:
|
|
2014
|
+
>>> users = [
|
|
2015
|
+
... {"name": "Alice", "age": 25},
|
|
2016
|
+
... {"name": "Bob", "age": 30}
|
|
2017
|
+
... ]
|
|
2018
|
+
>>> count = db.import_from_dict_list("users", users)
|
|
2019
|
+
"""
|
|
2020
|
+
if not data_list:
|
|
2021
|
+
return 0
|
|
2022
|
+
|
|
2023
|
+
safe_table_name = self._sanitize_identifier(table_name)
|
|
2024
|
+
|
|
2025
|
+
# 最初のdictからカラム名を取得
|
|
2026
|
+
columns = list(data_list[0].keys())
|
|
2027
|
+
safe_columns = [self._sanitize_identifier(col) for col in columns]
|
|
2028
|
+
placeholders = ", ".join(["?"] * len(columns))
|
|
2029
|
+
columns_sql = ", ".join(safe_columns)
|
|
2030
|
+
sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
|
|
2031
|
+
|
|
2032
|
+
# 各dictから値を抽出
|
|
2033
|
+
parameters_list = []
|
|
2034
|
+
for data in data_list:
|
|
2035
|
+
values = [data.get(col) for col in columns]
|
|
2036
|
+
parameters_list.append(tuple(values))
|
|
2037
|
+
|
|
2038
|
+
self.execute_many(sql, parameters_list)
|
|
2039
|
+
return len(data_list)
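# --- Editorial usage sketch (not part of the released source) ---
# export_table_to_dict() and import_from_dict_list() round-trip rows as plain
# dicts. The target table name is hypothetical and is assumed to already exist
# with matching columns, since import_from_dict_list() does not create it.
#
#     rows = db.export_table_to_dict("users")
#     # ...filter or transform the dicts here...
#     db.import_from_dict_list("users_archive", rows)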
|
|
2040
|
+
|
|
2041
|
+
def get_last_insert_rowid(self) -> int:
|
|
2042
|
+
"""
|
|
2043
|
+
最後に挿入されたROWIDを取得
|
|
2044
|
+
|
|
2045
|
+
Returns:
|
|
2046
|
+
最後に挿入されたROWID
|
|
2047
|
+
|
|
2048
|
+
Example:
|
|
2049
|
+
>>> db.sql_insert("users", {"name": "Alice"})
|
|
2050
|
+
>>> rowid = db.get_last_insert_rowid()
|
|
2051
|
+
"""
|
|
2052
|
+
return self._connection.last_insert_rowid()
|
|
2053
|
+
|
|
2054
|
+
def pragma(self, pragma_name: str, value: Any = None) -> Any:
|
|
2055
|
+
"""
|
|
2056
|
+
PRAGMA設定の取得/設定
|
|
2057
|
+
|
|
2058
|
+
Args:
|
|
2059
|
+
pragma_name: PRAGMA名
|
|
2060
|
+
value: 設定値(Noneの場合は取得のみ)
|
|
2061
|
+
|
|
2062
|
+
Returns:
|
|
2063
|
+
valueがNoneの場合は現在の値、そうでない場合はNone
|
|
2064
|
+
|
|
2065
|
+
Example:
|
|
2066
|
+
>>> # 取得
|
|
2067
|
+
>>> mode = db.pragma("journal_mode")
|
|
2068
|
+
|
|
2069
|
+
>>> # 設定
|
|
2070
|
+
>>> db.pragma("foreign_keys", 1)
|
|
2071
|
+
"""
|
|
2072
|
+
# Whitelist of allowed PRAGMA commands for security
|
|
2073
|
+
ALLOWED_PRAGMAS = {
|
|
2074
|
+
"foreign_keys",
|
|
2075
|
+
"journal_mode",
|
|
2076
|
+
"synchronous",
|
|
2077
|
+
"cache_size",
|
|
2078
|
+
"temp_store",
|
|
2079
|
+
"locking_mode",
|
|
2080
|
+
"auto_vacuum",
|
|
2081
|
+
"page_size",
|
|
2082
|
+
"encoding",
|
|
2083
|
+
"user_version",
|
|
2084
|
+
"schema_version",
|
|
2085
|
+
"wal_autocheckpoint",
|
|
2086
|
+
"busy_timeout",
|
|
2087
|
+
"query_only",
|
|
2088
|
+
"recursive_triggers",
|
|
2089
|
+
"secure_delete",
|
|
2090
|
+
"table_info",
|
|
2091
|
+
"index_list",
|
|
2092
|
+
"index_info",
|
|
2093
|
+
"database_list",
|
|
2094
|
+
}
|
|
2095
|
+
|
|
2096
|
+
if pragma_name not in ALLOWED_PRAGMAS:
|
|
2097
|
+
raise ValueError(f"PRAGMA '{pragma_name}' is not allowed. Allowed: {', '.join(sorted(ALLOWED_PRAGMAS))}")
|
|
2098
|
+
|
|
2099
|
+
if value is None:
|
|
2100
|
+
cursor = self.execute(f"PRAGMA {pragma_name}")
|
|
2101
|
+
result = cursor.fetchone()
|
|
2102
|
+
return result[0] if result else None
|
|
2103
|
+
else:
|
|
2104
|
+
# Validate value is safe (int, float, or simple string)
|
|
2105
|
+
if not isinstance(value, (int, float, str)):
|
|
2106
|
+
raise ValueError(f"PRAGMA value must be int, float, or str, got {type(value).__name__}")
|
|
2107
|
+
|
|
2108
|
+
# For string values, validate to prevent SQL injection
|
|
2109
|
+
if isinstance(value, str):
|
|
2110
|
+
# Only allow alphanumeric, underscore, dash, and dots for string values
|
|
2111
|
+
if not re.match(r"^[\w\-\.]+$", value):
|
|
2112
|
+
raise ValueError(
|
|
2113
|
+
"PRAGMA string value must contain only alphanumeric, underscore, dash, or dot characters"
|
|
2114
|
+
)
|
|
2115
|
+
value_str = f"'{value}'"
|
|
2116
|
+
else:
|
|
2117
|
+
value_str = str(value)
|
|
2118
|
+
|
|
2119
|
+
self.execute(f"PRAGMA {pragma_name} = {value_str}")
|
|
2120
|
+
return None
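# --- Editorial usage sketch (not part of the released source) ---
# pragma() only accepts whitelisted PRAGMAs; typical tuning reads/writes:
#
#     print("journal_mode:", db.pragma("journal_mode"))
#     db.pragma("foreign_keys", 1)
#     db.pragma("busy_timeout", 5000)        # milliseconds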
|
|
2121
|
+
|
|
2122
|
+
# ==================== Transaction Control ====================
|
|
2123
|
+
|
|
2124
|
+
def begin_transaction(self) -> None:
|
|
2125
|
+
"""
|
|
2126
|
+
トランザクションを開始
|
|
2127
|
+
|
|
2128
|
+
Note:
|
|
2129
|
+
SQLiteはネストされたトランザクションをサポートしていません。
|
|
2130
|
+
既にトランザクション中の場合、NanaSQLiteTransactionErrorが発生します。
|
|
2131
|
+
|
|
2132
|
+
Raises:
|
|
2133
|
+
NanaSQLiteTransactionError: 既にトランザクション中の場合
|
|
2134
|
+
NanaSQLiteConnectionError: 接続が閉じられている場合
|
|
2135
|
+
NanaSQLiteDatabaseError: トランザクション開始に失敗した場合
|
|
2136
|
+
|
|
2137
|
+
Example:
|
|
2138
|
+
>>> db.begin_transaction()
|
|
2139
|
+
>>> try:
|
|
2140
|
+
... db.sql_insert("users", {"name": "Alice"})
|
|
2141
|
+
... db.sql_insert("users", {"name": "Bob"})
|
|
2142
|
+
... db.commit()
|
|
2143
|
+
... except:
|
|
2144
|
+
... db.rollback()
|
|
2145
|
+
"""
|
|
2146
|
+
self._check_connection()
|
|
2147
|
+
|
|
2148
|
+
if self._in_transaction:
|
|
2149
|
+
raise NanaSQLiteTransactionError(
|
|
2150
|
+
"Transaction already in progress. "
|
|
2151
|
+
"SQLite does not support nested transactions. "
|
|
2152
|
+
"Please commit or rollback the current transaction first."
|
|
2153
|
+
)
|
|
2154
|
+
|
|
2155
|
+
try:
|
|
2156
|
+
self.execute("BEGIN IMMEDIATE")
|
|
2157
|
+
self._in_transaction = True
|
|
2158
|
+
self._transaction_depth = 1
|
|
2159
|
+
except Exception as e:
|
|
2160
|
+
raise NanaSQLiteDatabaseError(
|
|
2161
|
+
f"Failed to begin transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
|
|
2162
|
+
) from e
|
|
2163
|
+
|
|
2164
|
+
def commit(self) -> None:
|
|
2165
|
+
"""
|
|
2166
|
+
トランザクションをコミット
|
|
2167
|
+
|
|
2168
|
+
Raises:
|
|
2169
|
+
NanaSQLiteTransactionError: トランザクション外でコミットを試みた場合
|
|
2170
|
+
NanaSQLiteConnectionError: 接続が閉じられている場合
|
|
2171
|
+
NanaSQLiteDatabaseError: コミットに失敗した場合
|
|
2172
|
+
"""
|
|
2173
|
+
self._check_connection()
|
|
2174
|
+
|
|
2175
|
+
if not self._in_transaction:
|
|
2176
|
+
raise NanaSQLiteTransactionError(
|
|
2177
|
+
"No transaction in progress. Call begin_transaction() first or use the transaction() context manager."
|
|
2178
|
+
)
|
|
2179
|
+
|
|
2180
|
+
try:
|
|
2181
|
+
self.execute("COMMIT")
|
|
2182
|
+
self._in_transaction = False
|
|
2183
|
+
self._transaction_depth = 0
|
|
2184
|
+
except Exception as e:
|
|
2185
|
+
# コミット失敗時は状態を維持(ロールバックが必要)
|
|
2186
|
+
raise NanaSQLiteDatabaseError(
|
|
2187
|
+
f"Failed to commit transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
|
|
2188
|
+
) from e
|
|
2189
|
+
|
|
2190
|
+
def rollback(self) -> None:
|
|
2191
|
+
"""
|
|
2192
|
+
トランザクションをロールバック
|
|
2193
|
+
|
|
2194
|
+
Raises:
|
|
2195
|
+
NanaSQLiteTransactionError: トランザクション外でロールバックを試みた場合
|
|
2196
|
+
NanaSQLiteConnectionError: 接続が閉じられている場合
|
|
2197
|
+
NanaSQLiteDatabaseError: ロールバックに失敗した場合
|
|
2198
|
+
"""
|
|
2199
|
+
self._check_connection()
|
|
2200
|
+
|
|
2201
|
+
if not self._in_transaction:
|
|
2202
|
+
raise NanaSQLiteTransactionError(
|
|
2203
|
+
"No transaction in progress. Call begin_transaction() first or use the transaction() context manager."
|
|
2204
|
+
)
|
|
2205
|
+
|
|
2206
|
+
try:
|
|
2207
|
+
self.execute("ROLLBACK")
|
|
2208
|
+
self._in_transaction = False
|
|
2209
|
+
self._transaction_depth = 0
|
|
2210
|
+
except Exception as e:
|
|
2211
|
+
# ロールバック失敗は深刻なので状態をリセット
|
|
2212
|
+
self._in_transaction = False
|
|
2213
|
+
self._transaction_depth = 0
|
|
2214
|
+
raise NanaSQLiteDatabaseError(
|
|
2215
|
+
f"Failed to rollback transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
|
|
2216
|
+
) from e
|
|
2217
|
+
|
|
2218
|
+
def in_transaction(self) -> bool:
|
|
2219
|
+
"""
|
|
2220
|
+
現在トランザクション中かどうかを返す
|
|
2221
|
+
|
|
2222
|
+
Returns:
|
|
2223
|
+
bool: トランザクション中の場合True
|
|
2224
|
+
|
|
2225
|
+
Example:
|
|
2226
|
+
>>> db.begin_transaction()
|
|
2227
|
+
>>> print(db.in_transaction()) # True
|
|
2228
|
+
>>> db.commit()
|
|
2229
|
+
>>> print(db.in_transaction()) # False
|
|
2230
|
+
"""
|
|
2231
|
+
return self._in_transaction
|
|
2232
|
+
|
|
2233
|
+
def transaction(self):
|
|
2234
|
+
"""
|
|
2235
|
+
トランザクションのコンテキストマネージャ
|
|
2236
|
+
|
|
2237
|
+
コンテキストマネージャ内で例外が発生しない場合は自動的にコミット、
|
|
2238
|
+
例外が発生した場合は自動的にロールバックします。
|
|
2239
|
+
|
|
2240
|
+
Raises:
|
|
2241
|
+
NanaSQLiteTransactionError: 既にトランザクション中の場合
|
|
2242
|
+
|
|
2243
|
+
Example:
|
|
2244
|
+
>>> with db.transaction():
|
|
2245
|
+
... db.sql_insert("users", {"name": "Alice"})
|
|
2246
|
+
... db.sql_insert("users", {"name": "Bob"})
|
|
2247
|
+
... # 自動的にコミット、例外時はロールバック
|
|
2248
|
+
"""
|
|
2249
|
+
return _TransactionContext(self)
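# --- Editorial usage sketch (not part of the released source) ---
# The context manager commits on success and rolls back on any exception, so it
# usually replaces manual begin_transaction()/commit()/rollback(). The
# `accounts` table is illustrative.
#
#     with db.transaction():
#         db.sql_insert("accounts", {"name": "a", "balance": 100})
#         db.sql_update("accounts", {"balance": 0}, "name = ?", ("b",))
#
#     # manual equivalent:
#     db.begin_transaction()
#     try:
#         db.sql_insert("accounts", {"name": "c", "balance": 50})
#         db.commit()
#     except Exception:
#         db.rollback()
#         raise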
|
|
2250
|
+
|
|
2251
|
+
def table(
|
|
2252
|
+
self,
|
|
2253
|
+
table_name: str,
|
|
2254
|
+
cache_strategy: CacheType | Literal["unbounded", "lru"] | None = None,
|
|
2255
|
+
cache_size: int | None = None,
|
|
2256
|
+
):
|
|
2257
|
+
"""
|
|
2258
|
+
サブテーブル用のNanaSQLiteインスタンスを取得
|
|
2259
|
+
|
|
2260
|
+
新しいインスタンスを作成しますが、SQLite接続とロックは共有します。
|
|
2261
|
+
これにより、複数のテーブルインスタンスが同じ接続を使用して
|
|
2262
|
+
スレッドセーフに動作します。
|
|
2263
|
+
|
|
2264
|
+
Args:
|
|
2265
|
+
table_name: テーブル名
|
|
2266
|
+
cache_strategy: このテーブル用のキャッシュ戦略 (デフォルト: 親と同じ)
|
|
2267
|
+
cache_size: このテーブル用のキャッシュサイズ (デフォルト: 親と同じ)
|
|
2268
|
+
|
|
2269
|
+
⚠️ 重要な注意事項:
|
|
2270
|
+
- 同じテーブルに対して複数のインスタンスを作成しないでください
|
|
2271
|
+
各インスタンスは独立したキャッシュを持つため、キャッシュ不整合が発生します
|
|
2272
|
+
- 推奨: テーブルインスタンスを変数に保存して再利用してください
|
|
2273
|
+
|
|
2274
|
+
非推奨:
|
|
2275
|
+
sub1 = db.table("users")
|
|
2276
|
+
sub2 = db.table("users") # キャッシュ不整合の原因
|
|
2277
|
+
|
|
2278
|
+
推奨:
|
|
2279
|
+
users_db = db.table("users")
|
|
2280
|
+
# users_dbを使い回す
|
|
2281
|
+
|
|
2282
|
+
Returns:
|
|
2283
|
+
NanaSQLite: 新しいテーブルインスタンス
|
|
2284
|
+
|
|
2285
|
+
Raises:
|
|
2286
|
+
NanaSQLiteConnectionError: 接続が閉じられている場合
|
|
2287
|
+
|
|
2288
|
+
Example:
|
|
2289
|
+
>>> with NanaSQLite("app.db", table="main") as main_db:
|
|
2290
|
+
... users_db = main_db.table("users")
|
|
2291
|
+
... products_db = main_db.table("products")
|
|
2292
|
+
... users_db["user1"] = {"name": "Alice"}
|
|
2293
|
+
... products_db["prod1"] = {"name": "Laptop"}
|
|
2294
|
+
"""
|
|
2295
|
+
self._check_connection()
|
|
2296
|
+
|
|
2297
|
+
# 指定がなければデフォルト(UNBOUNDED)
|
|
2298
|
+
strat = cache_strategy if cache_strategy is not None else CacheType.UNBOUNDED
|
|
2299
|
+
size = cache_size
|
|
2300
|
+
|
|
2301
|
+
child = NanaSQLite(
|
|
2302
|
+
self._db_path,
|
|
2303
|
+
table=table_name,
|
|
2304
|
+
cache_strategy=strat,
|
|
2305
|
+
cache_size=size,
|
|
2306
|
+
_shared_connection=self._connection,
|
|
2307
|
+
_shared_lock=self._lock,
|
|
2308
|
+
)
|
|
2309
|
+
|
|
2310
|
+
# If the parent is the connection owner, the child is not.
|
|
2311
|
+
# This ensures only one instance (the owner) attempts to close the connection.
|
|
2312
|
+
if self._is_connection_owner:
|
|
2313
|
+
child._is_connection_owner = False
|
|
2314
|
+
|
|
2315
|
+
# 子インスタンスを追跡 (WeakSetに直接オブジェクトを追加すると、WeakSetが弱参照を保持する)
|
|
2316
|
+
self._child_instances.add(child)
|
|
2317
|
+
|
|
2318
|
+
return child
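# --- Editorial usage sketch (not part of the released source) ---
# Child instances share the parent's connection and lock, so create one per
# table and reuse it instead of calling table() repeatedly (see the warning in
# the docstring above). Table names and cache settings are illustrative.
#
#     with NanaSQLite("app.db", table="main") as main_db:
#         users_db = main_db.table("users", cache_strategy="lru", cache_size=1000)
#         users_db["user:1"] = {"name": "Alice"}
#         print(users_db["user:1"])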
|
|
2319
|
+
|
|
2320
|
+
|
|
2321
|
+
class _TransactionContext:
|
|
2322
|
+
"""トランザクションのコンテキストマネージャ"""
|
|
2323
|
+
|
|
2324
|
+
def __init__(self, db: NanaSQLite):
|
|
2325
|
+
self.db = db
|
|
2326
|
+
|
|
2327
|
+
def __enter__(self):
|
|
2328
|
+
self.db.begin_transaction()
|
|
2329
|
+
return self.db
|
|
2330
|
+
|
|
2331
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
2332
|
+
if exc_type is None:
|
|
2333
|
+
self.db.commit()
|
|
2334
|
+
else:
|
|
2335
|
+
self.db.rollback()
|
|
2336
|
+
return False
|