nanasqlite 1.3.3.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nanasqlite/core.py ADDED
@@ -0,0 +1,2336 @@
1
+ """
2
+ NanaSQLite: APSW SQLite-backed dict wrapper with memory caching.
3
+
4
+ 通常のPython dictをラップし、操作時にSQLite永続化処理を行う。
5
+ - 書き込み: 即時SQLiteへ永続化
6
+ - 読み込み: デフォルトは遅延ロード(使用時)、一度読み込んだらメモリ管理
7
+ - 一括ロード: bulk_load=Trueで起動時に全データをメモリに展開
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import logging
14
+ import os
15
+ import re
16
+ import threading
17
+ import warnings
18
+ import weakref
19
+ from collections.abc import Iterator, MutableMapping
20
+ from typing import Any, Literal
21
+
22
+ import apsw
23
+
24
+ try:
25
+ from cryptography.fernet import Fernet
26
+ from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
27
+ HAS_CRYPTOGRAPHY = True
28
+ except ImportError:
29
+ HAS_CRYPTOGRAPHY = False
30
+
31
+ from .cache import CacheStrategy, CacheType, create_cache
32
+ from .exceptions import (
33
+ NanaSQLiteClosedError,
34
+ NanaSQLiteConnectionError,
35
+ NanaSQLiteDatabaseError,
36
+ NanaSQLiteTransactionError,
37
+ NanaSQLiteValidationError,
38
+ )
39
+ from .sql_utils import fast_validate_sql_chars, sanitize_sql_for_function_scan
40
+
41
+ # 識別子バリデーション用の正規表現パターン(英数字とアンダースコアのみ、数字で開始しない)
42
+ IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+ # Optional fast JSON (orjson)
47
+ try:
48
+ import orjson # type: ignore
49
+
50
+ HAS_ORJSON = True
51
+ except Exception:
52
+ HAS_ORJSON = False
53
+
54
+
55
+ class NanaSQLite(MutableMapping):
56
+ """
57
+ APSW SQLite-backed dict wrapper with Security and Connection Enhancements (v1.2.0).
58
+ (APSW SQLiteをバックエンドとした、セキュリティ・接続管理強化版の辞書型ラッパー (v1.2.0))
59
+
60
+ Internally maintains a Python dict and synchronizes with SQLite during operations.
61
+ In v1.2.0, enhanced dynamic SQL validation, ReDoS protection, and strict connection management are introduced.
62
+
63
+ 内部でPython dictを保持し、操作時にSQLiteとの同期を行います。
64
+ v1.2.0では、動的SQLのバリデーション強化、ReDoS対策、および厳格な接続管理が導入されています。
65
+
66
+ Args:
67
+ db_path: SQLiteデータベースファイルのパス
68
+ table: 使用するテーブル名 (デフォルト: "data")
69
+ bulk_load: Trueの場合、初期化時に全データをメモリに読み込む
70
+ strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否 (v1.2.0)
71
+ max_clause_length: SQL句の最大長(ReDoS対策、v1.2.0)
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ db_path: str,
77
+ table: str = "data",
78
+ bulk_load: bool = False,
79
+ optimize: bool = True,
80
+ cache_size_mb: int = 64,
81
+ strict_sql_validation: bool = True,
82
+ allowed_sql_functions: list[str] | None = None,
83
+ forbidden_sql_functions: list[str] | None = None,
84
+ max_clause_length: int | None = 1000,
85
+ cache_strategy: CacheType | Literal["unbounded", "lru", "ttl"] = CacheType.UNBOUNDED,
86
+ cache_size: int | None = None,
87
+ cache_ttl: float | None = None,
88
+ cache_persistence_ttl: bool = False,
89
+ encryption_key: str | bytes | None = None,
90
+ encryption_mode: Literal["aes-gcm", "chacha20", "fernet"] = "aes-gcm",
91
+ _shared_connection: apsw.Connection | None = None,
92
+ _shared_lock: threading.RLock | None = None,
93
+ ):
94
+ """
95
+ Args:
96
+ db_path: SQLiteデータベースファイルのパス
97
+ table: 使用するテーブル名 (デフォルト: "data")
98
+ bulk_load: Trueの場合、初期化時に全データをメモリに読み込む
99
+ optimize: Trueの場合、WALモードなど高速化設定を適用
100
+ cache_size_mb: SQLiteキャッシュサイズ(MB)、デフォルト64MB
101
+ strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
102
+ allowed_sql_functions: 追加で許可するSQL関数のリスト
103
+ forbidden_sql_functions: 明示的に禁止するSQL関数のリスト
104
+ max_clause_length: SQL句の最大長(ReDoS対策)。Noneで制限なし
105
+ _shared_connection: 内部用:共有する接続(table()メソッドで使用)
106
+ _shared_lock: 内部用:共有するロック(table()メソッドで使用)
107
+ """
108
+ self._db_path: str = db_path
109
+ self._table: str = table
110
+
111
+ # Encryption setup
112
+ self._encryption_key = encryption_key
113
+ self._encryption_mode = encryption_mode
114
+ self._fernet: Fernet | None = None
115
+ self._aead: AESGCM | ChaCha20Poly1305 | None = None
116
+
117
+ if encryption_key:
118
+ if not HAS_CRYPTOGRAPHY:
119
+ raise ImportError(
120
+ "Encryption requires the 'cryptography' library. "
121
+ "Install it with: pip install nanasqlite[encryption]"
122
+ )
123
+ # Support both str (base64) and bytes
124
+ key_bytes: bytes = encryption_key.encode("utf-8") if isinstance(encryption_key, str) else encryption_key
125
+
126
+ if encryption_mode == "fernet":
127
+ self._fernet = Fernet(key_bytes)
128
+ elif encryption_mode == "aes-gcm":
129
+ self._aead = AESGCM(key_bytes)
130
+ elif encryption_mode == "chacha20":
131
+ self._aead = ChaCha20Poly1305(key_bytes)
132
+ else:
133
+ raise ValueError(f"Unsupported encryption_mode: {encryption_mode}")
134
+
135
+ # Setup Persistence TTL callback if enabled
136
+ on_expire = None
137
+ if (cache_strategy == CacheType.TTL or cache_strategy == "ttl") and cache_persistence_ttl:
138
+
139
+ def _expire_callback(key: str, value: Any) -> None:
140
+ try:
141
+ # Use a new or shared connection to delete from DB
142
+ # Implementation detail: we need to be careful with locks
143
+ self._delete_from_db_on_expire(key)
144
+ except Exception as e:
145
+ logger.error(f"Failed to delete expired key '{key}' from DB: {e}")
146
+
147
+ on_expire = _expire_callback
148
+
149
+ self._cache: CacheStrategy = create_cache(cache_strategy, cache_size, ttl=cache_ttl, on_expire=on_expire)
150
+ self._data = self._cache.get_data()
151
+ self._lru_mode = (
152
+ (cache_strategy == CacheType.LRU) or (cache_strategy == "lru") or
153
+ (cache_strategy == CacheType.TTL) or (cache_strategy == "ttl")
154
+ )
155
+
156
+ if not self._lru_mode:
157
+ # Unbounded 以外のモードでは内部辞書の直接参照を使用しない場合があるが、
158
+ # 現状の設計では _cached_keys を通じて存在チェックを行っている
159
+ self._cached_keys = self._cache._cached_keys # type: ignore
160
+ else:
161
+ # LRU/TTL モードでは、データ保持自体が存在の証
162
+ self._cached_keys = self._data # type: ignore
163
+
164
+ self._all_loaded: bool = False # 全データ読み込み済みフラグ
165
+
166
+ # セキュリティ設定
167
+ self.strict_sql_validation = strict_sql_validation
168
+ self.allowed_sql_functions = set(allowed_sql_functions or [])
169
+ self.forbidden_sql_functions = set(forbidden_sql_functions or [])
170
+ self.max_clause_length = max_clause_length
171
+
172
+ # デフォルトで許可されるSQL関数
173
+ self._default_allowed_functions = {
174
+ "COUNT",
175
+ "SUM",
176
+ "AVG",
177
+ "MIN",
178
+ "MAX",
179
+ "ABS",
180
+ "UPPER",
181
+ "LOWER",
182
+ "LENGTH",
183
+ "ROUND",
184
+ "COALESCE",
185
+ "IFNULL",
186
+ "NULLIF",
187
+ "STRFTIME",
188
+ "DATE",
189
+ "TIME",
190
+ "DATETIME",
191
+ "JULIANDAY",
192
+ }
193
+
194
+ # トランザクション状態管理
195
+ self._in_transaction: bool = False # トランザクション中かどうか
196
+ self._transaction_depth: int = 0 # ネストレベル(警告用)
197
+
198
+ # 子インスタンスの追跡(リソース管理用)
199
+ self._child_instances = weakref.WeakSet() # WeakSetによる弱参照追跡(死んだ参照は自動的にクリーンアップ)
200
+ self._is_closed: bool = False # 接続が閉じられたか
201
+ self._parent_closed: bool = False # 親接続が閉じられたか
202
+
203
+ # 接続とロックの共有または新規作成
204
+ if _shared_connection is not None:
205
+ # 接続を共有(table()メソッドから呼ばれた場合)
206
+ self._connection: apsw.Connection = _shared_connection
207
+ self._lock = _shared_lock if _shared_lock is not None else threading.RLock()
208
+ self._is_connection_owner = False # 接続の所有者ではない
209
+ else:
210
+ # 新規接続を作成(通常の初期化)
211
+ try:
212
+ self._connection: apsw.Connection = apsw.Connection(db_path)
213
+ except apsw.Error as e:
214
+ raise NanaSQLiteConnectionError(f"Failed to connect to database: {e}") from e
215
+ self._lock = threading.RLock()
216
+ self._is_connection_owner = True # 接続の所有者
217
+
218
+ # 高速化設定(接続の所有者のみ)
219
+ if optimize:
220
+ self._apply_optimizations(cache_size_mb)
221
+
222
+ # テーブル作成
223
+ with self._lock:
224
+ self._connection.execute(f"""
225
+ CREATE TABLE IF NOT EXISTS {self._table} (
226
+ key TEXT PRIMARY KEY,
227
+ value TEXT
228
+ )
229
+ """)
230
+
231
+ # 一括ロード
232
+ if bulk_load:
233
+ self.load_all()
234
+
235
+ def _apply_optimizations(self, cache_size_mb: int = 64) -> None:
236
+ """
237
+ APSWの高速化設定を適用
238
+
239
+ - WALモード: 書き込み並行性向上、30ms+ -> 1ms以下に改善
240
+ - synchronous=NORMAL: 安全性を保ちつつ高速化
241
+ - mmap: メモリマップドI/Oで読み込み高速化
242
+ - cache_size: SQLiteのメモリキャッシュ増加
243
+ - temp_store=MEMORY: 一時テーブルをメモリに
244
+ """
245
+ cursor = self._connection.cursor()
246
+
247
+ # WALモード(Write-Ahead Logging)- 書き込み高速化の核心
248
+ cursor.execute("PRAGMA journal_mode = WAL")
249
+
250
+ # synchronous=NORMAL: WALモードでは安全かつ高速
251
+ cursor.execute("PRAGMA synchronous = NORMAL")
252
+
253
+ # メモリマップドI/O(256MB)- 読み込み高速化
254
+ cursor.execute("PRAGMA mmap_size = 268435456")
255
+
256
+ # キャッシュサイズ(負の値=KB単位)
257
+ cache_kb = cache_size_mb * 1024
258
+ cursor.execute(f"PRAGMA cache_size = -{cache_kb}")
259
+
260
+ # 一時テーブルをメモリに
261
+ cursor.execute("PRAGMA temp_store = MEMORY")
262
+
263
+ # ページサイズ最適化(新規DBのみ効果あり)
264
+ cursor.execute("PRAGMA page_size = 4096")
265
+
266
+ @staticmethod
267
+ def _sanitize_identifier(identifier: str) -> str:
268
+ """
269
+ SQLiteの識別子(テーブル名、カラム名など)を検証
270
+
271
+ Args:
272
+ identifier: 検証する識別子
273
+
274
+ Returns:
275
+ 検証済み識別子(ダブルクォートで囲まれる)
276
+
277
+ Raises:
278
+ NanaSQLiteValidationError: 識別子が無効な場合
279
+
280
+ Note:
281
+ SQLiteの識別子は以下をサポート:
282
+ - 英数字とアンダースコア
283
+ - 数字で開始しない
284
+ - SQLキーワードも引用符で囲めば使用可能
285
+ """
286
+ if not identifier:
287
+ raise NanaSQLiteValidationError("Identifier cannot be empty")
288
+
289
+ # 基本的な検証: 英数字とアンダースコアのみ許可
290
+ if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier):
291
+ raise NanaSQLiteValidationError(
292
+ f"Invalid identifier '{identifier}': must start with letter or underscore "
293
+ "and contain only alphanumeric characters and underscores"
294
+ )
295
+
296
+ # SQLiteではダブルクォートで囲むことで識別子をエスケープ
297
+ return f'"{identifier}"'
298
+
299
+ # ==================== Private Methods ====================
300
+
301
+ def __hash__(self):
302
+ # MutableMapping inhibits hashing by default because it's mutable.
303
+ # However, we need identity-based hashing to track instances in WeakSet.
304
+ # This is safe as long as we don't rely on content-based hashing in sets.
305
+ # NOTE: This technically violates the rule that a==b implies hash(a)==hash(b),
306
+ # because __eq__ implements content equivalence while __hash__ implements identity.
307
+ # This is an intentional design choice to support WeakSet management while providing
308
+ # convenient dict-like equality comparisons.
309
+ return id(self)
310
+
311
+ def __eq__(self, other):
312
+ """
313
+ 辞書のような等価性比較を実装
314
+
315
+ 他のマッピング(dictやMutableMapping)との比較では内容ベースの比較を行い、
316
+ それ以外では同一性(is)での比較を行う。
317
+
318
+ Args:
319
+ other: 比較対象のオブジェクト
320
+
321
+ Returns:
322
+ bool: 等価な場合True、そうでない場合False
323
+
324
+ Raises:
325
+ NanaSQLiteClosedError: 接続が閉じられている場合
326
+ """
327
+ if isinstance(other, (dict, MutableMapping)):
328
+ # Ensure the connection is open; propagate NanaSQLiteClosedError if not.
329
+ self._check_connection()
330
+ return dict(self.items()) == dict(other.items())
331
+ return self is other
332
+
333
+ def _check_connection(self) -> None:
334
+ """
335
+ 接続が有効かチェック
336
+
337
+ Raises:
338
+ NanaSQLiteClosedError: 接続が閉じられている、または親が閉じられている場合
339
+ """
340
+ if self._is_closed:
341
+ raise NanaSQLiteClosedError(f"Database connection is closed (table: '{self._table}').")
342
+ if self._parent_closed:
343
+ raise NanaSQLiteClosedError(
344
+ f"Parent database connection is closed (table: '{self._table}'). "
345
+ "If you obtained this instance via .table(), ensure the primary "
346
+ "NanaSQLite instance remains open during usage."
347
+ )
348
+
349
+ def _validate_expression(
350
+ self,
351
+ expr: str | None,
352
+ strict: bool | None = None,
353
+ allowed: list[str] | None = None,
354
+ forbidden: list[str] | None = None,
355
+ override_allowed: bool = False,
356
+ context: Literal["order_by", "group_by", "where", "column"] | None = None,
357
+ ) -> None:
358
+ """
359
+ SQL表現(ORDER BY, GROUP BY, 列名等)を検証。
360
+
361
+ Args:
362
+ expr: 検証するSQL表現
363
+ strict: 強制停止モード。Noneの場合はインスタンス設定を使用。
364
+ allowed: 今回のクエリで追加/置換して許可する関数。
365
+ forbidden: 今回のクエリで明示的に禁止する関数。
366
+ override_allowed: Trueの場合、インスタンス許可設定を無視して今回のallowedのみ参照。
367
+ context: エラーメッセージのコンテキスト ("order_by", "group_by", "where", "column")
368
+
369
+ Raises:
370
+ NanaSQLiteValidationError: strict=True かつ不適切な表現の場合
371
+ UserWarning: strict=False かつ不適切な表現の場合(実行は許可)
372
+ """
373
+ if not expr:
374
+ return
375
+
376
+ # 0. legacy check for SQL injection patterns
377
+ # test_security.py compatibility: raise ValueError for strictly dangerous patterns
378
+ # We use a combined message to satisfy both test_security.py ("Potentially dangerous...")
379
+ # and test_security_additions.py ("Invalid...")
380
+ warning_text = "Potentially dangerous SQL pattern"
381
+
382
+ context_labels = {
383
+ "order_by": "order_by clause",
384
+ "group_by": "group_by clause",
385
+ "where": "where clause",
386
+ "column": "column name",
387
+ }
388
+ label = context_labels.get(context)
389
+
390
+ # Standardize format: "Invalid [label]: [warning_text]" (or "Invalid: [warning_text]" if no label)
391
+ # This satisfies both legacy and new security tests.
392
+ if label:
393
+ full_msg = f"Invalid {label}: {warning_text}"
394
+ else:
395
+ full_msg = f"Invalid: {warning_text}"
396
+
397
+ dangerous_patterns = [
398
+ (r";", full_msg),
399
+ (r"--", full_msg),
400
+ (r"/\*", full_msg),
401
+ (r"\b(DROP|DELETE|UPDATE|INSERT|TRUNCATE|ALTER)\b", full_msg),
402
+ ]
403
+
404
+ # 0.5. Fast character-set validation (ReDoS countermeasure)
405
+ if not fast_validate_sql_chars(str(expr)):
406
+ # If invalid characters are found, we apply strict or warning
407
+ # Note: This is a preventative layer.
408
+ # We use full_msg to maintain compatibility with existing tests expecting "Invalid [label]: ..."
409
+ msg = f"{full_msg} or invalid characters detected."
410
+ if strict or (strict is None and self.strict_sql_validation):
411
+ raise ValueError(msg)
412
+ else:
413
+ warnings.warn(msg, UserWarning, stacklevel=2)
414
+
415
+ for pattern, msg in dangerous_patterns:
416
+ if re.search(pattern, str(expr), re.IGNORECASE):
417
+ # Block highly dangerous patterns in strict mode, but only warn in non-strict
418
+ if strict or (strict is None and self.strict_sql_validation):
419
+ raise ValueError(msg)
420
+ else:
421
+ warnings.warn(msg, UserWarning, stacklevel=2)
422
+
423
+ # 1. 長さ制限 (ReDoS対策)
424
+ max_len = self.max_clause_length
425
+ if max_len and len(expr) > max_len:
426
+ msg = f"SQL expression exceeds maximum length of {max_len} characters."
427
+ if strict or (strict is None and self.strict_sql_validation):
428
+ raise NanaSQLiteValidationError(msg)
429
+ else:
430
+ warnings.warn(msg, UserWarning, stacklevel=2)
431
+
432
+ # 2. 禁止リストの整理 (メソッド指定を優先、なければインスタンス設定)
433
+ forbidden_list = set(forbidden) if forbidden is not None else self.forbidden_sql_functions
434
+ if "*" in forbidden_list:
435
+ msg = "All SQL functions are forbidden for this expression."
436
+ if strict or (strict is None and self.strict_sql_validation):
437
+ raise NanaSQLiteValidationError(msg)
438
+ else:
439
+ warnings.warn(msg, UserWarning, stacklevel=2)
440
+ return
441
+
442
+ # 3. 許可リストの整理
443
+ effective_allowed = set()
444
+ if not override_allowed:
445
+ effective_allowed.update(self._default_allowed_functions)
446
+ effective_allowed.update(self.allowed_sql_functions)
447
+
448
+ if allowed:
449
+ effective_allowed.update(allowed)
450
+
451
+ # 禁止リストに含まれるものは許可から削除
452
+ effective_allowed -= forbidden_list
453
+
454
+ # 4. 関数呼び出しの抽出
455
+ # 文字列リテラルやコメントをマスクした上で関数呼び出しを検索
456
+ # これにより、SELECT 'COUNT(' ... のようなパターンでの誤検知を防ぐ
457
+ sanitized_expr = sanitize_sql_for_function_scan(expr)
458
+ matches = re.findall(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", sanitized_expr)
459
+
460
+ for func in matches:
461
+ func_upper = func.upper()
462
+
463
+ # 明示的に禁止されている場合
464
+ if func_upper in forbidden_list:
465
+ msg = f"SQL function '{func_upper}' is explicitly forbidden."
466
+ if strict or (strict is None and self.strict_sql_validation):
467
+ raise NanaSQLiteValidationError(msg)
468
+ else:
469
+ warnings.warn(msg, UserWarning, stacklevel=2)
470
+ continue
471
+
472
+ # 許可リストにない場合
473
+ if func_upper not in effective_allowed:
474
+ msg = (
475
+ f"SQL function '{func_upper}' is not in the allowed list. "
476
+ "Use 'allowed_sql_functions' to permit it if you trust this function."
477
+ )
478
+ if strict or (strict is None and self.strict_sql_validation):
479
+ raise NanaSQLiteValidationError(msg)
480
+ else:
481
+ warnings.warn(msg, UserWarning, stacklevel=2)
482
+
483
+ def _mark_parent_closed(self) -> None:
484
+ """
485
+ 親インスタンスから呼ばれ、親が閉じられたことをマークする
486
+ """
487
+ self._parent_closed = True
488
+
489
+ def _serialize(self, value: Any) -> bytes | str:
490
+ """シリアライズ (JSON -> Encryption if enabled)"""
491
+ # Use fastest available JSON serializer
492
+ if HAS_ORJSON:
493
+ # orjson returns bytes
494
+ data = orjson.dumps(value)
495
+ json_str = None
496
+ else:
497
+ json_str = json.dumps(value, ensure_ascii=False)
498
+ data = json_str.encode("utf-8")
499
+
500
+ if self._fernet:
501
+ return self._fernet.encrypt(data)
502
+
503
+ if self._aead:
504
+ # Generate 12 bytes nonce
505
+ nonce = os.urandom(12)
506
+ ciphertext = self._aead.encrypt(nonce, data, None)
507
+ # Combine nonce + ciphertext
508
+ return nonce + ciphertext
509
+
510
+ # No encryption: store as TEXT for compatibility/perf (str)
511
+ if HAS_ORJSON:
512
+ # Decode once to keep DB storage as TEXT
513
+ return data.decode("utf-8")
514
+ return json_str
515
+
516
+ def _deserialize(self, value: bytes | str) -> Any:
517
+ """デシリアライズ (Decryption if enabled -> JSON)"""
518
+ if self._fernet:
519
+ decoded = self._fernet.decrypt(value).decode("utf-8")
520
+ if HAS_ORJSON:
521
+ return orjson.loads(decoded)
522
+ return json.loads(decoded)
523
+
524
+ if self._aead:
525
+ if not isinstance(value, bytes):
526
+ # Fallback or manual check if stored as string accidentally
527
+ if HAS_ORJSON:
528
+ return orjson.loads(value)
529
+ return json.loads(value)
530
+
531
+ # Split nonce (12B) and ciphertext
532
+ nonce = value[:12]
533
+ ciphertext = value[12:]
534
+ decoded = self._aead.decrypt(nonce, ciphertext, None).decode("utf-8")
535
+ if HAS_ORJSON:
536
+ return orjson.loads(decoded)
537
+ return json.loads(decoded)
538
+
539
+ # No encryption path
540
+ if HAS_ORJSON:
541
+ return orjson.loads(value)
542
+ return json.loads(value)
543
+
544
+ def _write_to_db(self, key: str, value: Any) -> None:
545
+ """即時書き込み: SQLiteに値を保存"""
546
+ serialized = self._serialize(value)
547
+ with self._lock:
548
+ self._connection.execute(
549
+ f"INSERT OR REPLACE INTO {self._table} (key, value) VALUES (?, ?)", # nosec
550
+ (key, serialized),
551
+ )
552
+
553
+ def _read_from_db(self, key: str) -> Any | None:
554
+ """SQLiteから値を読み込み"""
555
+ with self._lock:
556
+ cursor = self._connection.execute(
557
+ f"SELECT value FROM {self._table} WHERE key = ?", # nosec
558
+ (key,),
559
+ )
560
+ row = cursor.fetchone()
561
+ if row is None:
562
+ return None
563
+ return self._deserialize(row[0])
564
+
565
+ def _delete_from_db(self, key: str) -> None:
566
+ """SQLiteから値を削除"""
567
+ with self._lock:
568
+ self._connection.execute(
569
+ f"DELETE FROM {self._table} WHERE key = ?", # nosec
570
+ (key,),
571
+ )
572
+
573
+ def _get_all_keys_from_db(self) -> list:
574
+ """SQLiteから全キーを取得"""
575
+ with self._lock:
576
+ cursor = self._connection.execute(
577
+ f"SELECT key FROM {self._table}" # nosec
578
+ )
579
+ return [row[0] for row in cursor]
580
+
581
+ def _ensure_cached(self, key: str) -> bool:
582
+ """
583
+ キーがキャッシュにない場合、DBから読み込む(遅延ロード)
584
+ Returns: キーが存在するかどうか
585
+ """
586
+ # FAST PATH for default Unbounded mode
587
+ if not self._lru_mode:
588
+ if key in self._cached_keys:
589
+ return key in self._data
590
+ else:
591
+ if key in self._data:
592
+ return True
593
+
594
+ # DBから読み込み
595
+ value = self._read_from_db(key)
596
+
597
+ if value is not None:
598
+ if self._lru_mode or (hasattr(self._cache, "_max_size") and self._cache._max_size):
599
+ self._cache.set(key, value)
600
+ else:
601
+ self._data[key] = value
602
+ self._cached_keys.add(key)
603
+ return True
604
+
605
+ # Value is None (not in DB)
606
+ if not self._lru_mode:
607
+ self._cached_keys.add(key)
608
+ return False
609
+
610
+ # ==================== Dict Interface ====================
611
+
612
+ def __getitem__(self, key: str) -> Any:
613
+ """dict[key] - 遅延ロード後、メモリから取得"""
614
+ if self._ensure_cached(key):
615
+ # LRU updates order even on __getitem__
616
+ if self._lru_mode:
617
+ return self._cache.get(key)
618
+ return self._data[key]
619
+ raise KeyError(key)
620
+
621
+ def __setitem__(self, key: str, value: Any) -> None:
622
+ """dict[key] = value - 即時書き込み + メモリ更新"""
623
+ self._check_connection()
624
+ # メモリ更新
625
+ if self._lru_mode or (hasattr(self._cache, "_max_size") and self._cache._max_size):
626
+ self._cache.set(key, value)
627
+ else:
628
+ self._data[key] = value
629
+ self._cached_keys.add(key)
630
+ # 即時書き込み
631
+ self._write_to_db(key, value)
632
+
633
+ def __delitem__(self, key: str) -> None:
634
+ """del dict[key] - 即時削除"""
635
+ if not self._ensure_cached(key):
636
+ raise KeyError(key)
637
+ # メモリから削除
638
+ if self._lru_mode:
639
+ self._cache.delete(key)
640
+ else:
641
+ self._data.pop(key, None)
642
+ self._cached_keys.discard(key)
643
+
644
+ # DBから削除
645
+ self._delete_from_db(key)
646
+
647
+ def __contains__(self, key: str) -> bool:
648
+ """
649
+ key in dict - キーの存在確認
650
+
651
+ キャッシュにある場合はO(1)、ない場合は軽量なEXISTSクエリを使用。
652
+ 存在確認のみの場合、value全体を読み込まないため高速。
653
+ """
654
+ # FAST PATH
655
+ if key in self._cached_keys:
656
+ return key in self._data
657
+
658
+ # 軽量な存在確認クエリ(valueを読み込まない)
659
+ with self._lock:
660
+ cursor = self._connection.execute(
661
+ f"SELECT 1 FROM {self._table} WHERE key = ? LIMIT 1", # nosec
662
+ (key,), # nosec
663
+ )
664
+ exists = cursor.fetchone() is not None
665
+
666
+ if exists:
667
+ # 存在をマークするが、値は読み込まない(次回アクセス時に遅延ロード)
668
+ if self._lru_mode:
669
+ self._cache.mark_cached(key)
670
+ else:
671
+ self._cached_keys.add(key)
672
+ return True
673
+ else:
674
+ # 存在しないこともキャッシュ
675
+ if not self._lru_mode:
676
+ self._cached_keys.add(key)
677
+ return False
678
+
679
+ def __len__(self) -> int:
680
+ """len(dict) - DBの実際の件数を返す"""
681
+ with self._lock:
682
+ cursor = self._connection.execute(
683
+ f"SELECT COUNT(*) FROM {self._table}" # nosec
684
+ )
685
+ return cursor.fetchone()[0]
686
+
687
+ def __iter__(self) -> Iterator[str]:
688
+ """for key in dict"""
689
+ return iter(self.keys())
690
+
691
+ def __repr__(self) -> str:
692
+ return f"NanaSQLite({self._db_path!r}, table={self._table!r}, cached={self._cache.size})"
693
+
694
+ # ==================== Dict Methods ====================
695
+
696
+ def keys(self) -> list:
697
+ """全キーを取得(DBから)"""
698
+ return self._get_all_keys_from_db()
699
+
700
+ def values(self) -> list:
701
+ """全値を取得(一括ロードしてからメモリから)"""
702
+ self._check_connection()
703
+ self.load_all()
704
+ return list(self._cache.get_data().values())
705
+
706
+ def items(self) -> list:
707
+ """全アイテムを取得(一括ロードしてからメモリから)"""
708
+ self.load_all()
709
+ return list(self._cache.get_data().items())
710
+
711
+ def get(self, key: str, default: Any = None) -> Any:
712
+ """dict.get(key, default)"""
713
+ if self._ensure_cached(key):
714
+ if self._lru_mode:
715
+ return self._cache.get(key)
716
+ return self._data[key]
717
+ return default
718
+
719
+ def get_fresh(self, key: str, default: Any = None) -> Any:
720
+ """
721
+ DBから直接読み込み、キャッシュを更新して値を返す
722
+
723
+ キャッシュをバイパスしてDBから最新の値を取得する。
724
+ `execute()`でDBを直接変更した後などに使用。
725
+
726
+ 通常の`get()`よりオーバーヘッドがあるため、
727
+ キャッシュとDBの不整合が想定される場合のみ使用推奨。
728
+
729
+ Args:
730
+ key: 取得するキー
731
+ default: キーが存在しない場合のデフォルト値
732
+
733
+ Returns:
734
+ DBから取得した最新の値(存在しない場合はdefault)
735
+
736
+ Example:
737
+ >>> db.execute("UPDATE data SET value = ? WHERE key = ?", ('"new"', "key"))
738
+ >>> value = db.get_fresh("key") # DBから最新値を取得
739
+ """
740
+ # DBから直接読み込み
741
+ value = self._read_from_db(key)
742
+
743
+ if value is not None:
744
+ # キャッシュを更新
745
+ if self._lru_mode:
746
+ self._cache.set(key, value)
747
+ else:
748
+ self._data[key] = value
749
+ self._cached_keys.add(key)
750
+ return value
751
+ else:
752
+ # 存在しない場合はキャッシュからも削除
753
+ if self._lru_mode:
754
+ self._cache.delete(key)
755
+ else:
756
+ self._data.pop(key, None)
757
+ self._cached_keys.add(key) # 「存在しない」ことをマーク
758
+ return default
759
+
760
+ def batch_get(self, keys: list[str]) -> dict[str, Any]:
761
+ """
762
+ 複数のキーを一度に取得(効率的な一括ロード)
763
+
764
+ 1回の `SELECT IN (...)` クエリで複数のキーをDBから取得する。
765
+ 取得した値は自動的にキャッシュに保存される。
766
+
767
+ Args:
768
+ keys: 取得するキーのリスト
769
+
770
+ Returns:
771
+ 取得に成功したキーと値の dict
772
+
773
+ Example:
774
+ >>> results = db.batch_get(["user1", "user2", "user3"])
775
+ >>> print(results) # {"user1": {...}, "user2": {...}}
776
+ """
777
+ if not keys:
778
+ return {}
779
+
780
+ results = {}
781
+ missing_keys = []
782
+
783
+ # 1. キャッシュから取得可能なものをチェック
784
+ for key in keys:
785
+ if self._cache.is_cached(key):
786
+ val = self._cache.get(key)
787
+ if val is not None:
788
+ results[key] = val
789
+ else:
790
+ missing_keys.append(key)
791
+
792
+ if not missing_keys:
793
+ return results
794
+
795
+ # 2. DBから足りない分を一括取得
796
+ placeholders = ",".join(["?"] * len(missing_keys))
797
+ sql = f"SELECT key, value FROM {self._table} WHERE key IN ({placeholders})" # nosec
798
+
799
+ with self._lock:
800
+ cursor = self._connection.execute(sql, tuple(missing_keys))
801
+ for key, val_str in cursor:
802
+ value = self._deserialize(val_str)
803
+ self._cache.set(key, value)
804
+ results[key] = value
805
+
806
+ # 3. DBにも存在しなかったキーを「存在しない」としてキャッシュ
807
+ found_keys = set(results.keys())
808
+ for key in missing_keys:
809
+ if key not in found_keys:
810
+ self._cache.mark_cached(key)
811
+
812
+ return results
813
+
814
+ def pop(self, key: str, *args) -> Any:
815
+ """dict.pop(key[, default])"""
816
+ self._check_connection()
817
+ if self._ensure_cached(key):
818
+ value = self._cache.get(key)
819
+ self._cache.delete(key)
820
+ self._delete_from_db(key)
821
+ return value
822
+ if args:
823
+ return args[0]
824
+ raise KeyError(key)
825
+
826
+ def update(self, mapping: dict = None, **kwargs) -> None:
827
+ """dict.update(mapping) - 一括更新"""
828
+ if mapping:
829
+ for key, value in mapping.items():
830
+ self[key] = value
831
+ for key, value in kwargs.items():
832
+ self[key] = value
833
+
834
+ def clear(self) -> None:
835
+ """dict.clear() - 全削除"""
836
+ self._cache.clear()
837
+ self._all_loaded = False
838
+ with self._lock:
839
+ self._connection.execute(f"DELETE FROM {self._table}") # nosec
840
+
841
+ def setdefault(self, key: str, default: Any = None) -> Any:
842
+ """dict.setdefault(key, default)"""
843
+ if self._ensure_cached(key):
844
+ return self._cache.get(key)
845
+ self[key] = default
846
+ return default
847
+
848
+ # ==================== Special Methods ====================
849
+
850
+ def load_all(self) -> None:
851
+ """一括読み込み: 全データをメモリに展開"""
852
+ if self._all_loaded:
853
+ return
854
+
855
+ with self._lock:
856
+ cursor = self._connection.execute(
857
+ f"SELECT key, value FROM {self._table}" # nosec
858
+ )
859
+ rows = list(cursor) # ロック内でフェッチ
860
+
861
+ for key, value in rows:
862
+ self._cache.set(key, self._deserialize(value))
863
+
864
+ self._all_loaded = True
865
+
866
+ def refresh(self, key: str = None) -> None:
867
+ """
868
+ キャッシュを更新(DBから再読み込み)
869
+
870
+ Args:
871
+ key: 特定のキーのみ更新。Noneの場合は全キャッシュをクリアして再読み込み
872
+ """
873
+ if key is not None:
874
+ # FAST PATH for performance
875
+ if not self._lru_mode:
876
+ self._data.pop(key, None)
877
+ self._cached_keys.discard(key)
878
+ else:
879
+ self._cache.invalidate(key)
880
+ self._ensure_cached(key)
881
+ else:
882
+ self.clear_cache()
883
+
884
+ def is_cached(self, key: str) -> bool:
885
+ """キーがキャッシュ済みかどうか"""
886
+ # FAST PATH for performance
887
+ if not self._lru_mode:
888
+ return key in self._cached_keys
889
+ return self._cache.is_cached(key)
890
+
891
+ def batch_update(self, mapping: dict[str, Any]) -> None:
892
+ """
893
+ 一括書き込み(トランザクション + executemany使用で超高速)
894
+
895
+ 大量のデータを一度に書き込む場合、通常のupdateより10-100倍高速。
896
+ v1.0.3rc5でexecutemanyによる最適化を追加。
897
+
898
+ Args:
899
+ mapping: 書き込むキーと値のdict
900
+
901
+ Returns:
902
+ None
903
+
904
+ Example:
905
+ >>> db.batch_update({"key1": "value1", "key2": "value2", ...})
906
+ """
907
+ if not mapping:
908
+ return # 空の場合は何もしない
909
+
910
+ cursor = self._connection.cursor()
911
+ cursor.execute("BEGIN IMMEDIATE")
912
+ try:
913
+ # 事前にシリアライズしてexecutemany用のタプルリストを作成
914
+ params = [(key, self._serialize(value)) for key, value in mapping.items()]
915
+ cursor.executemany(
916
+ f"INSERT OR REPLACE INTO {self._table} (key, value) VALUES (?, ?)", # nosec
917
+ params,
918
+ )
919
+ # キャッシュ更新
920
+ for key, value in mapping.items():
921
+ if self._lru_mode:
922
+ self._cache.set(key, value)
923
+ else:
924
+ self._data[key] = value
925
+ self._cached_keys.add(key)
926
+ cursor.execute("COMMIT")
927
+ except Exception:
928
+ cursor.execute("ROLLBACK")
929
+ raise
930
+
931
+ def batch_delete(self, keys: list[str]) -> None:
932
+ """
933
+ 一括削除(トランザクション + executemany使用で高速)
934
+
935
+ v1.0.3rc5でexecutemanyによる最適化を追加。
936
+
937
+ Args:
938
+ keys: 削除するキーのリスト
939
+
940
+ Returns:
941
+ None
942
+ """
943
+ self._check_connection()
944
+ if not keys:
945
+ return # 空の場合は何もしない
946
+
947
+ cursor = self._connection.cursor()
948
+ cursor.execute("BEGIN IMMEDIATE")
949
+ try:
950
+ # executemany用のタプルリストを作成
951
+ params = [(key,) for key in keys]
952
+ cursor.executemany(
953
+ f"DELETE FROM {self._table} WHERE key = ?", # nosec
954
+ params,
955
+ )
956
+ # キャッシュ更新
957
+ for key in keys:
958
+ if self._lru_mode:
959
+ self._cache.delete(key)
960
+ else:
961
+ self._data.pop(key, None)
962
+ self._cached_keys.discard(key)
963
+ cursor.execute("COMMIT")
964
+ except Exception:
965
+ cursor.execute("ROLLBACK")
966
+ raise
967
+
968
+ def to_dict(self) -> dict:
969
+ """全データをPython dictとして取得"""
970
+ self._check_connection()
971
+ self.load_all()
972
+ return dict(self._data)
973
+
974
+ def copy(self) -> dict:
975
+ """浅いコピーを作成(標準dictを返す)"""
976
+ return self.to_dict()
977
+
978
+ def clear_cache(self) -> None:
979
+ """
980
+ メモリキャッシュをクリア
981
+
982
+ DBのデータは削除せず、メモリ上のキャッシュのみ破棄します。
983
+ """
984
+ self._cache.clear()
985
+ self._all_loaded = False
986
+
987
+ def _delete_from_db_on_expire(self, key: str) -> None:
988
+ """有効期限切れ時にDBからデータを削除 (内部用)"""
989
+ with self._lock:
990
+ if self._is_closed:
991
+ return
992
+ try:
993
+ self._connection.execute(f'DELETE FROM "{self._table}" WHERE key = ?', (key,)) # nosec
994
+ except apsw.Error as e:
995
+ logger.error(f"SQL error during background expiration for key '{key}': {e}")
996
+
997
+ def close(self) -> None:
998
+ """
999
+ データベース接続を閉じる
1000
+
1001
+ 注意: table()メソッドで作成されたインスタンスは接続を共有しているため、
1002
+ 接続の所有者(最初に作成されたインスタンス)のみが接続を閉じます。
1003
+
1004
+ Raises:
1005
+ NanaSQLiteTransactionError: トランザクション中にクローズを試みた場合
1006
+ """
1007
+ if self._is_closed:
1008
+ return # 既に閉じられている場合は何もしない
1009
+
1010
+ if self._in_transaction:
1011
+ raise NanaSQLiteTransactionError(
1012
+ "Cannot close connection while transaction is in progress. Please commit or rollback first."
1013
+ )
1014
+
1015
+ # 子インスタンスに通知
1016
+ for child in self._child_instances:
1017
+ child._mark_parent_closed()
1018
+
1019
+ self._child_instances.clear()
1020
+ self._is_closed = True
1021
+
1022
+ if self._is_connection_owner:
1023
+ try:
1024
+ self._connection.close()
1025
+ except apsw.Error as e:
1026
+ # 接続クローズの失敗は警告に留める
1027
+ import warnings
1028
+
1029
+ warnings.warn(f"Failed to close database connection: {e}", stacklevel=2)
1030
+
1031
+ def __enter__(self):
1032
+ """コンテキストマネージャ対応"""
1033
+ return self
1034
+
1035
+ def __exit__(self, exc_type, exc_val, exc_tb):
1036
+ """コンテキストマネージャ対応"""
1037
+ self.close()
1038
+ return False
1039
+
1040
+ # ==================== Pydantic Support ====================
1041
+
1042
+ def set_model(self, key: str, model: Any) -> None:
1043
+ """
1044
+ Pydanticモデルを保存
1045
+
1046
+ Pydanticモデル(BaseModelを継承したクラス)をシリアライズして保存。
1047
+ model_dump()メソッドを使用してdictに変換し、モデルのクラス情報も保存。
1048
+
1049
+ Args:
1050
+ key: 保存するキー
1051
+ model: Pydanticモデルのインスタンス
1052
+
1053
+ Example:
1054
+ >>> from pydantic import BaseModel
1055
+ >>> class User(BaseModel):
1056
+ ... name: str
1057
+ ... age: int
1058
+ >>> user = User(name="Nana", age=20)
1059
+ >>> db.set_model("user", user)
1060
+ """
1061
+ try:
1062
+ # Pydanticモデルかチェック (model_dump メソッドの存在で判定)
1063
+ if hasattr(model, "model_dump"):
1064
+ data = {
1065
+ "__pydantic_model__": f"{model.__class__.__module__}.{model.__class__.__qualname__}",
1066
+ "__pydantic_data__": model.model_dump(),
1067
+ }
1068
+ self[key] = data
1069
+ else:
1070
+ raise TypeError(f"Object of type {type(model)} is not a Pydantic model")
1071
+ except Exception as e:
1072
+ raise TypeError(f"Failed to serialize Pydantic model: {e}")
1073
+
1074
+ def get_model(self, key: str, model_class: type = None) -> Any:
1075
+ """
1076
+ Pydanticモデルを取得
1077
+
1078
+ 保存されたPydanticモデルをデシリアライズして復元。
1079
+ model_classが指定されていない場合は、保存時のクラス情報を使用。
1080
+
1081
+ Args:
1082
+ key: 取得するキー
1083
+ model_class: Pydanticモデルのクラス(Noneの場合は自動検出を試みる)
1084
+
1085
+ Returns:
1086
+ Pydanticモデルのインスタンス
1087
+
1088
+ Example:
1089
+ >>> user = db.get_model("user", User)
1090
+ >>> print(user.name) # "Nana"
1091
+ """
1092
+ data = self[key]
1093
+
1094
+ if isinstance(data, dict) and "__pydantic_model__" in data and "__pydantic_data__" in data:
1095
+ if model_class is None:
1096
+ # 自動検出は複雑なため、model_classを推奨
1097
+ raise ValueError("model_class must be provided for get_model()")
1098
+
1099
+ # Pydanticモデルとして復元
1100
+ try:
1101
+ return model_class(**data["__pydantic_data__"])
1102
+ except Exception as e:
1103
+ raise ValueError(f"Failed to deserialize Pydantic model: {e}")
1104
+ elif model_class is not None:
1105
+ # 通常のdictをPydanticモデルに変換
1106
+ try:
1107
+ return model_class(**data)
1108
+ except Exception as e:
1109
+ raise ValueError(f"Failed to create Pydantic model from data: {e}")
1110
+ else:
1111
+ raise ValueError("Data is not a Pydantic model and no model_class provided")
1112
+
1113
+ # ==================== Direct SQL Execution ====================
1114
+
1115
+ def execute(self, sql: str, parameters: tuple | None = None) -> apsw.Cursor:
1116
+ """
1117
+ SQLを直接実行
1118
+
1119
+ 任意のSQL文を実行できる。SELECT、INSERT、UPDATE、DELETEなど。
1120
+ パラメータバインディングをサポート(SQLインジェクション対策)。
1121
+
1122
+ .. warning::
1123
+ このメソッドで直接デフォルトテーブル(data)を操作した場合、
1124
+ 内部キャッシュ(_data)と不整合が発生する可能性があります。
1125
+ キャッシュを更新するには `refresh()` を呼び出してください。
1126
+
1127
+ Args:
1128
+ sql: 実行するSQL文
1129
+ parameters: SQLのパラメータ(?プレースホルダー用)
1130
+
1131
+ Returns:
1132
+ APSWのCursorオブジェクト(結果の取得に使用)
1133
+
1134
+ Raises:
1135
+ NanaSQLiteConnectionError: 接続が閉じられている場合
1136
+ NanaSQLiteDatabaseError: SQL実行エラー
1137
+
1138
+ Example:
1139
+ >>> cursor = db.execute("SELECT * FROM data WHERE key LIKE ?", ("user%",))
1140
+ >>> for row in cursor:
1141
+ ... print(row)
1142
+
1143
+ # キャッシュ更新が必要な場合:
1144
+ >>> db.execute("UPDATE data SET value = ? WHERE key = ?", ('"new"', "key"))
1145
+ >>> db.refresh("key") # キャッシュを更新
1146
+ """
1147
+ self._check_connection()
1148
+
1149
+ try:
1150
+ with self._lock:
1151
+ if parameters is None:
1152
+ return self._connection.execute(sql)
1153
+ else:
1154
+ return self._connection.execute(sql, parameters)
1155
+ except apsw.Error as e:
1156
+ raise NanaSQLiteDatabaseError(f"Failed to execute SQL: {e}", original_error=e) from e
1157
+
1158
+ def execute_many(self, sql: str, parameters_list: list[tuple]) -> None:
1159
+ """
1160
+ SQLを複数のパラメータで一括実行
1161
+
1162
+ 同じSQL文を複数のパラメータセットで実行(トランザクション使用)。
1163
+ 大量のINSERTやUPDATEを高速に実行できる。
1164
+
1165
+ Args:
1166
+ sql: 実行するSQL文
1167
+ parameters_list: パラメータのリスト
1168
+
1169
+ Example:
1170
+ >>> db.execute_many(
1171
+ ... "INSERT OR REPLACE INTO custom (id, name) VALUES (?, ?)",
1172
+ ... [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
1173
+ ... )
1174
+ """
1175
+ with self._lock:
1176
+ cursor = self._connection.cursor()
1177
+ cursor.execute("BEGIN IMMEDIATE")
1178
+ try:
1179
+ for parameters in parameters_list:
1180
+ cursor.execute(sql, parameters)
1181
+ cursor.execute("COMMIT")
1182
+ except apsw.Error:
1183
+ cursor.execute("ROLLBACK")
1184
+ raise
1185
+
1186
+ def fetch_one(self, sql: str, parameters: tuple = None) -> tuple | None:
1187
+ """
1188
+ SQLを実行して1行取得
1189
+
1190
+ Args:
1191
+ sql: 実行するSQL文
1192
+ parameters: SQLのパラメータ
1193
+
1194
+ Returns:
1195
+ 1行の結果(tuple)、結果がない場合はNone
1196
+
1197
+ Example:
1198
+ >>> row = db.fetch_one("SELECT value FROM data WHERE key = ?", ("user",))
1199
+ >>> print(row[0])
1200
+ """
1201
+ cursor = self.execute(sql, parameters)
1202
+ return cursor.fetchone()
1203
+
1204
+ def fetch_all(self, sql: str, parameters: tuple = None) -> list[tuple]:
1205
+ """
1206
+ SQLを実行して全行取得
1207
+
1208
+ Args:
1209
+ sql: 実行するSQL文
1210
+ parameters: SQLのパラメータ
1211
+
1212
+ Returns:
1213
+ 全行の結果(tupleのリスト)
1214
+
1215
+ Example:
1216
+ >>> rows = db.fetch_all("SELECT key, value FROM data WHERE key LIKE ?", ("user%",))
1217
+ >>> for key, value in rows:
1218
+ ... print(key, value)
1219
+ """
1220
+ cursor = self.execute(sql, parameters)
1221
+ return cursor.fetchall()
1222
+
1223
+ # ==================== SQLite Wrapper Functions ====================
1224
+
1225
+ def create_table(self, table_name: str, columns: dict, if_not_exists: bool = True, primary_key: str = None) -> None:
1226
+ """
1227
+ テーブルを作成
1228
+
1229
+ Args:
1230
+ table_name: テーブル名
1231
+ columns: カラム定義のdict(カラム名: SQL型)
1232
+ if_not_exists: Trueの場合、存在しない場合のみ作成
1233
+ primary_key: プライマリキーのカラム名(Noneの場合は指定なし)
1234
+
1235
+ Example:
1236
+ >>> db.create_table("users", {
1237
+ ... "id": "INTEGER PRIMARY KEY",
1238
+ ... "name": "TEXT NOT NULL",
1239
+ ... "email": "TEXT UNIQUE",
1240
+ ... "age": "INTEGER"
1241
+ ... })
1242
+ >>> db.create_table("posts", {
1243
+ ... "id": "INTEGER",
1244
+ ... "title": "TEXT",
1245
+ ... "content": "TEXT"
1246
+ ... }, primary_key="id")
1247
+ """
1248
+ if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
1249
+ safe_table_name = self._sanitize_identifier(table_name)
1250
+
1251
+ column_defs = []
1252
+ for col_name, col_type in columns.items():
1253
+ safe_col_name = self._sanitize_identifier(col_name)
1254
+ column_defs.append(f"{safe_col_name} {col_type}")
1255
+
1256
+ if primary_key:
1257
+ safe_pk = self._sanitize_identifier(primary_key)
1258
+ if not any(primary_key.upper() in col.upper() and "PRIMARY KEY" in col.upper() for col in column_defs):
1259
+ column_defs.append(f"PRIMARY KEY ({safe_pk})")
1260
+
1261
+ columns_sql = ", ".join(column_defs)
1262
+ sql = f"CREATE TABLE {if_not_exists_clause}{safe_table_name} ({columns_sql})"
1263
+
1264
+ self.execute(sql)
1265
+
1266
+ def create_index(
1267
+ self, index_name: str, table_name: str, columns: list[str], unique: bool = False, if_not_exists: bool = True
1268
+ ) -> None:
1269
+ """
1270
+ インデックスを作成
1271
+
1272
+ Args:
1273
+ index_name: インデックス名
1274
+ table_name: テーブル名
1275
+ columns: インデックスを作成するカラムのリスト
1276
+ unique: Trueの場合、ユニークインデックスを作成
1277
+ if_not_exists: Trueの場合、存在しない場合のみ作成
1278
+
1279
+ Example:
1280
+ >>> db.create_index("idx_users_email", "users", ["email"], unique=True)
1281
+ >>> db.create_index("idx_posts_user", "posts", ["user_id", "created_at"])
1282
+ """
1283
+ unique_clause = "UNIQUE " if unique else ""
1284
+ if_not_exists_clause = "IF NOT EXISTS " if if_not_exists else ""
1285
+ safe_index_name = self._sanitize_identifier(index_name)
1286
+ safe_table_name = self._sanitize_identifier(table_name)
1287
+ safe_columns = [self._sanitize_identifier(col) for col in columns]
1288
+ columns_sql = ", ".join(safe_columns)
1289
+
1290
+ sql = (
1291
+ f"CREATE {unique_clause}INDEX {if_not_exists_clause}{safe_index_name} ON {safe_table_name} ({columns_sql})"
1292
+ )
1293
+ self.execute(sql)
1294
+
1295
+ def query(
1296
+ self,
1297
+ table_name: str = None,
1298
+ columns: list[str] = None,
1299
+ where: str = None,
1300
+ parameters: tuple = None,
1301
+ order_by: str = None,
1302
+ limit: int = None,
1303
+ strict_sql_validation: bool = None,
1304
+ allowed_sql_functions: list[str] = None,
1305
+ forbidden_sql_functions: list[str] = None,
1306
+ override_allowed: bool = False,
1307
+ ) -> list[dict]:
1308
+ """
1309
+ シンプルなSELECTクエリを実行
1310
+
1311
+ Args:
1312
+ table_name: テーブル名(Noneの場合はデフォルトテーブル)
1313
+ columns: 取得するカラムのリスト(Noneの場合は全カラム)
1314
+ where: WHERE句の条件(パラメータバインディング使用推奨)
1315
+ parameters: WHERE句のパラメータ
1316
+ order_by: ORDER BY句
1317
+ limit: LIMIT句
1318
+ strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
1319
+ allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
1320
+ forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
1321
+ override_allowed: Trueの場合、インスタンス許可設定を無視
1322
+
1323
+ Returns:
1324
+ 結果のリスト(各行はdict)
1325
+
1326
+ Example:
1327
+ >>> # デフォルトテーブルから全データ取得
1328
+ >>> results = db.query()
1329
+
1330
+ >>> # 条件付き検索
1331
+ >>> results = db.query(
1332
+ ... table_name="users",
1333
+ ... columns=["id", "name", "email"],
1334
+ ... where="age > ?",
1335
+ ... parameters=(20,),
1336
+ ... order_by="name ASC",
1337
+ ... limit=10
1338
+ ... )
1339
+ """
1340
+ if table_name is None:
1341
+ table_name = self._table
1342
+
1343
+ safe_table_name = self._sanitize_identifier(table_name)
1344
+
1345
+ # バリデーション
1346
+ self._validate_expression(
1347
+ where,
1348
+ strict_sql_validation,
1349
+ allowed_sql_functions,
1350
+ forbidden_sql_functions,
1351
+ override_allowed,
1352
+ context="where",
1353
+ )
1354
+ self._validate_expression(
1355
+ order_by,
1356
+ strict_sql_validation,
1357
+ allowed_sql_functions,
1358
+ forbidden_sql_functions,
1359
+ override_allowed,
1360
+ context="order_by",
1361
+ )
1362
+ if columns:
1363
+ for col in columns:
1364
+ # 関数使用の可能性を考慮して識別子サニタイズは行わないがバリデーションは行う
1365
+ self._validate_expression(
1366
+ col,
1367
+ strict_sql_validation,
1368
+ allowed_sql_functions,
1369
+ forbidden_sql_functions,
1370
+ override_allowed,
1371
+ context="column",
1372
+ )
1373
+
1374
+ # カラム指定
1375
+ if columns is None:
1376
+ columns_sql = "*"
1377
+ # カラム名は後でPRAGMAから取得
1378
+ else:
1379
+ # 識別子(カラム名のみ)の場合はサニタイズ、式の場合はそのまま(バリデーション済み)
1380
+ safe_cols = []
1381
+ for col in columns:
1382
+ if IDENTIFIER_PATTERN.match(col):
1383
+ safe_cols.append(self._sanitize_identifier(col))
1384
+ else:
1385
+ safe_cols.append(col)
1386
+ columns_sql = ", ".join(safe_cols)
1387
+
1388
+ # Validate limit is an integer and non-negative if provided
1389
+ if limit is not None:
1390
+ if not isinstance(limit, int):
1391
+ raise ValueError(f"limit must be an integer, got {type(limit).__name__}")
1392
+ if limit < 0:
1393
+ raise ValueError("limit must be non-negative")
1394
+
1395
+ # SQL構築
1396
+ sql = f"SELECT {columns_sql} FROM {safe_table_name}" # nosec
1397
+
1398
+ if where:
1399
+ sql += f" WHERE {where}"
1400
+
1401
+ if order_by:
1402
+ sql += f" ORDER BY {order_by}"
1403
+
1404
+ if limit is not None:
1405
+ sql += f" LIMIT {limit}"
1406
+
1407
+ # 実行
1408
+ cursor = self.execute(sql, parameters)
1409
+
1410
+ # カラム名取得
1411
+ if columns is None:
1412
+ # 全カラムの場合、テーブル情報から取得
1413
+ pragma_cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
1414
+ col_names = [row[1] for row in pragma_cursor]
1415
+ else:
1416
+ # Extract aliases from AS clauses, similar to query_with_pagination
1417
+ col_names = []
1418
+ for col in columns:
1419
+ parts = re.split(r"\s+as\s+", col, flags=re.IGNORECASE)
1420
+ if len(parts) > 1:
1421
+ # Use the alias (after AS)
1422
+ col_names.append(parts[-1].strip().strip('"').strip("'"))
1423
+ else:
1424
+ # Use the column expression as-is
1425
+ col_names.append(col.strip())
1426
+
1427
+ # 結果をdictのリストに変換
1428
+ results = []
1429
+ for row in cursor:
1430
+ results.append(dict(zip(col_names, row)))
1431
+
1432
+ return results
1433
+
1434
+ def table_exists(self, table_name: str) -> bool:
1435
+ """
1436
+ テーブルの存在確認
1437
+
1438
+ Args:
1439
+ table_name: テーブル名
1440
+
1441
+ Returns:
1442
+ 存在する場合True、しない場合False
1443
+
1444
+ Example:
1445
+ >>> if db.table_exists("users"):
1446
+ ... print("users table exists")
1447
+ """
1448
+ cursor = self.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
1449
+ return cursor.fetchone() is not None
1450
+
1451
+ def list_tables(self) -> list[str]:
1452
+ """
1453
+ データベース内の全テーブル一覧を取得
1454
+
1455
+ Returns:
1456
+ テーブル名のリスト
1457
+
1458
+ Example:
1459
+ >>> tables = db.list_tables()
1460
+ >>> print(tables) # ['data', 'users', 'posts']
1461
+ """
1462
+ cursor = self.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
1463
+ return [row[0] for row in cursor]
1464
+
1465
+ def drop_table(self, table_name: str, if_exists: bool = True) -> None:
1466
+ """
1467
+ テーブルを削除
1468
+
1469
+ Args:
1470
+ table_name: テーブル名
1471
+ if_exists: Trueの場合、存在する場合のみ削除(エラーを防ぐ)
1472
+
1473
+ Example:
1474
+ >>> db.drop_table("old_table")
1475
+ >>> db.drop_table("temp", if_exists=True)
1476
+ """
1477
+ if_exists_clause = "IF EXISTS " if if_exists else ""
1478
+ safe_table_name = self._sanitize_identifier(table_name)
1479
+ sql = f"DROP TABLE {if_exists_clause}{safe_table_name}"
1480
+ self.execute(sql)
1481
+
1482
+ def drop_index(self, index_name: str, if_exists: bool = True) -> None:
1483
+ """
1484
+ インデックスを削除
1485
+
1486
+ Args:
1487
+ index_name: インデックス名
1488
+ if_exists: Trueの場合、存在する場合のみ削除
1489
+
1490
+ Example:
1491
+ >>> db.drop_index("idx_users_email")
1492
+ """
1493
+ if_exists_clause = "IF EXISTS " if if_exists else ""
1494
+ safe_index_name = self._sanitize_identifier(index_name)
1495
+ sql = f"DROP INDEX {if_exists_clause}{safe_index_name}"
1496
+ self.execute(sql)
1497
+
1498
+ def alter_table_add_column(self, table_name: str, column_name: str, column_type: str, default: Any = None) -> None:
1499
+ """
1500
+ 既存テーブルにカラムを追加
1501
+
1502
+ Args:
1503
+ table_name: テーブル名
1504
+ column_name: カラム名
1505
+ column_type: カラムの型(SQL型)
1506
+ default: デフォルト値(Noneの場合は指定なし)
1507
+
1508
+ Example:
1509
+ >>> db.alter_table_add_column("users", "phone", "TEXT")
1510
+ >>> db.alter_table_add_column("users", "status", "TEXT", default="'active'")
1511
+ """
1512
+ safe_table_name = self._sanitize_identifier(table_name)
1513
+ safe_column_name = self._sanitize_identifier(column_name)
1514
+ # column_type is a SQL type string - validate it doesn't contain dangerous characters
1515
+ # Also check for closing parenthesis which could break out of ALTER TABLE structure
1516
+ if any(c in column_type for c in [";", "'", ")"]) or "--" in column_type or "/*" in column_type:
1517
+ raise ValueError(f"Invalid or dangerous column type: {column_type}")
1518
+
1519
+ sql = f"ALTER TABLE {safe_table_name} ADD COLUMN {safe_column_name} {column_type}"
1520
+ if default is not None:
1521
+ # For default values: if it's a string, ensure it's properly quoted and escaped
1522
+ if isinstance(default, str):
1523
+ # Strip leading/trailing single quotes if present, then escape and re-quote
1524
+ stripped = default
1525
+ if stripped.startswith("'") and stripped.endswith("'") and len(stripped) >= 2:
1526
+ stripped = stripped[1:-1]
1527
+ # Escape single quotes for SQL string literal (double them: ' becomes '')
1528
+ escaped_default = stripped.replace("'", "''")
1529
+ default = f"'{escaped_default}'"
1530
+ sql += f" DEFAULT {default}"
1531
+ self.execute(sql)
1532
+
1533
+ def get_table_schema(self, table_name: str) -> list[dict]:
1534
+ """
1535
+ テーブル構造を取得
1536
+
1537
+ Args:
1538
+ table_name: テーブル名
1539
+
1540
+ Returns:
1541
+ カラム情報のリスト(各カラムはdict)
1542
+
1543
+ Example:
1544
+ >>> schema = db.get_table_schema("users")
1545
+ >>> for col in schema:
1546
+ ... print(f"{col['name']}: {col['type']}")
1547
+ """
1548
+ safe_table_name = self._sanitize_identifier(table_name)
1549
+ cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
1550
+ columns = []
1551
+ for row in cursor:
1552
+ columns.append(
1553
+ {
1554
+ "cid": row[0],
1555
+ "name": row[1],
1556
+ "type": row[2],
1557
+ "notnull": bool(row[3]),
1558
+ "default_value": row[4],
1559
+ "pk": bool(row[5]),
1560
+ }
1561
+ )
1562
+ return columns
1563
+
1564
+ def list_indexes(self, table_name: str = None) -> list[dict]:
1565
+ """
1566
+ インデックス一覧を取得
1567
+
1568
+ Args:
1569
+ table_name: テーブル名(Noneの場合は全インデックス)
1570
+
1571
+ Returns:
1572
+ インデックス情報のリスト
1573
+
1574
+ Example:
1575
+ >>> indexes = db.list_indexes("users")
1576
+ >>> for idx in indexes:
1577
+ ... print(f"{idx['name']}: {idx['columns']}")
1578
+ """
1579
+ if table_name:
1580
+ cursor = self.execute(
1581
+ "SELECT name, tbl_name, sql FROM sqlite_master WHERE type='index' AND tbl_name=? ORDER BY name",
1582
+ (table_name,),
1583
+ )
1584
+ else:
1585
+ cursor = self.execute("SELECT name, tbl_name, sql FROM sqlite_master WHERE type='index' ORDER BY name")
1586
+
1587
+ indexes = []
1588
+ for row in cursor:
1589
+ if row[0] and not row[0].startswith("sqlite_"): # Skip auto-created indexes
1590
+ indexes.append({"name": row[0], "table": row[1], "sql": row[2]})
1591
+ return indexes
1592
+
1593
+ # ==================== Data Operation Wrappers ====================
1594
+
1595
+ def sql_insert(self, table_name: str, data: dict) -> int:
1596
+ """
1597
+ dictから直接INSERT
1598
+
1599
+ Args:
1600
+ table_name: テーブル名
1601
+ data: カラム名と値のdict
1602
+
1603
+ Returns:
1604
+ 挿入されたROWID
1605
+
1606
+ Example:
1607
+ >>> rowid = db.sql_insert("users", {
1608
+ ... "name": "Alice",
1609
+ ... "email": "alice@example.com",
1610
+ ... "age": 25
1611
+ ... })
1612
+ """
1613
+ safe_table_name = self._sanitize_identifier(table_name)
1614
+ safe_columns = [self._sanitize_identifier(col) for col in data.keys()]
1615
+ values = list(data.values())
1616
+ placeholders = ", ".join(["?"] * len(values))
1617
+ columns_sql = ", ".join(safe_columns)
1618
+
1619
+ sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
1620
+ self.execute(sql, tuple(values))
1621
+
1622
+ return self.get_last_insert_rowid()
1623
+
1624
+ def sql_update(self, table_name: str, data: dict, where: str, parameters: tuple = None) -> int:
1625
+ """
1626
+ dictとwhere条件でUPDATE
1627
+
1628
+ Args:
1629
+ table_name: テーブル名
1630
+ data: 更新するカラム名と値のdict
1631
+ where: WHERE句の条件
1632
+ parameters: WHERE句のパラメータ
1633
+
1634
+ Returns:
1635
+ 更新された行数
1636
+
1637
+ Example:
1638
+ >>> count = db.sql_update("users",
1639
+ ... {"age": 26, "status": "active"},
1640
+ ... "name = ?",
1641
+ ... ("Alice",)
1642
+ ... )
1643
+ """
1644
+ safe_table_name = self._sanitize_identifier(table_name)
1645
+ safe_set_items = [f"{self._sanitize_identifier(col)} = ?" for col in data.keys()]
1646
+ set_clause = ", ".join(safe_set_items)
1647
+ values = list(data.values())
1648
+
1649
+ sql = f"UPDATE {safe_table_name} SET {set_clause} WHERE {where}" # nosec
1650
+
1651
+ if parameters:
1652
+ values.extend(parameters)
1653
+
1654
+ self.execute(sql, tuple(values))
1655
+ return self._connection.changes()
1656
+
1657
+ def sql_delete(self, table_name: str, where: str, parameters: tuple = None) -> int:
1658
+ """
1659
+ where条件でDELETE
1660
+
1661
+ Args:
1662
+ table_name: テーブル名
1663
+ where: WHERE句の条件
1664
+ parameters: WHERE句のパラメータ
1665
+
1666
+ Returns:
1667
+ 削除された行数
1668
+
1669
+ Example:
1670
+ >>> count = db.sql_delete("users", "age < ?", (18,))
1671
+ """
1672
+ safe_table_name = self._sanitize_identifier(table_name)
1673
+ sql = f"DELETE FROM {safe_table_name} WHERE {where}" # nosec
1674
+ self.execute(sql, parameters)
1675
+ return self._connection.changes()
1676
+
1677
+ def upsert(self, table_name: str, data: dict, conflict_columns: list[str] = None) -> int:
1678
+ """
1679
+ INSERT OR REPLACE の簡易版(upsert)
1680
+
1681
+ Args:
1682
+ table_name: テーブル名
1683
+ data: カラム名と値のdict
1684
+ conflict_columns: 競合判定に使用するカラム(Noneの場合はINSERT OR REPLACE)
1685
+
1686
+ Returns:
1687
+ 挿入/更新されたROWID
1688
+
1689
+ Example:
1690
+ >>> # 単純なINSERT OR REPLACE
1691
+ >>> db.upsert("users", {"id": 1, "name": "Alice", "age": 25})
1692
+
1693
+ >>> # ON CONFLICT句を使用
1694
+ >>> db.upsert("users",
1695
+ ... {"email": "alice@example.com", "name": "Alice", "age": 26},
1696
+ ... conflict_columns=["email"]
1697
+ ... )
1698
+ """
1699
+ safe_table_name = self._sanitize_identifier(table_name)
1700
+ safe_columns = [self._sanitize_identifier(col) for col in data.keys()]
1701
+ values = list(data.values())
1702
+ placeholders = ", ".join(["?"] * len(values))
1703
+ columns_sql = ", ".join(safe_columns)
1704
+
1705
+ if conflict_columns:
1706
+ # ON CONFLICT を使用
1707
+ safe_conflict_cols = [self._sanitize_identifier(col) for col in conflict_columns]
1708
+ conflict_cols_sql = ", ".join(safe_conflict_cols)
1709
+
1710
+ update_items = [
1711
+ f"{self._sanitize_identifier(col)} = excluded.{self._sanitize_identifier(col)}"
1712
+ for col in data.keys()
1713
+ if col not in conflict_columns
1714
+ ]
1715
+
1716
+ if update_items:
1717
+ update_clause = ", ".join(update_items)
1718
+ else:
1719
+ # 全カラムが競合カラムの場合は、何もしない(既存データを保持)
1720
+ sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders}) " # nosec
1721
+ sql += f"ON CONFLICT({conflict_cols_sql}) DO NOTHING" # nosec
1722
+ self.execute(sql, tuple(values))
1723
+ # When DO NOTHING is triggered, no row is inserted, return 0
1724
+ # Check only the most recent operation's change count
1725
+ if self._connection.changes() == 0:
1726
+ return 0
1727
+ return self.get_last_insert_rowid()
1728
+
1729
+ sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders}) " # nosec
1730
+ sql += f"ON CONFLICT({conflict_cols_sql}) DO UPDATE SET {update_clause}" # nosec
1731
+ else:
1732
+ # INSERT OR REPLACE
1733
+ sql = f"INSERT OR REPLACE INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
1734
+
1735
+ self.execute(sql, tuple(values))
1736
+ return self.get_last_insert_rowid()
1737
+
1738
+ def count(
1739
+ self,
1740
+ table_name: str = None,
1741
+ where: str = None,
1742
+ parameters: tuple = None,
1743
+ strict_sql_validation: bool = None,
1744
+ allowed_sql_functions: list[str] = None,
1745
+ forbidden_sql_functions: list[str] = None,
1746
+ override_allowed: bool = False,
1747
+ ) -> int:
1748
+ """
1749
+ レコード数を取得
1750
+
1751
+ Args:
1752
+ table_name: テーブル名(Noneの場合はデフォルトテーブル)
1753
+ where: WHERE句の条件(オプション)
1754
+ parameters: WHERE句のパラメータ
1755
+ strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
1756
+ allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
1757
+ forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
1758
+ override_allowed: Trueの場合、インスタンス許可設定を無視
1759
+
1760
+ Example:
1761
+ >>> total = db.count("users")
1762
+ >>> adults = db.count("users", "age >= ?", (18,))
1763
+ """
1764
+ if table_name is None:
1765
+ table_name = self._table
1766
+
1767
+ safe_table_name = self._sanitize_identifier(table_name)
1768
+
1769
+ # バリデーション
1770
+ self._validate_expression(
1771
+ where, strict_sql_validation, allowed_sql_functions, forbidden_sql_functions, override_allowed
1772
+ )
1773
+
1774
+ sql = f"SELECT COUNT(*) FROM {safe_table_name}" # nosec
1775
+ if where:
1776
+ sql += f" WHERE {where}"
1777
+
1778
+ cursor = self.execute(sql, parameters)
1779
+ return cursor.fetchone()[0]
1780
+
1781
+ def exists(self, table_name: str, where: str, parameters: tuple = None) -> bool:
1782
+ """
1783
+ レコードの存在確認
1784
+
1785
+ Args:
1786
+ table_name: テーブル名
1787
+ where: WHERE句の条件
1788
+ parameters: WHERE句のパラメータ
1789
+
1790
+ Returns:
1791
+ 存在する場合True
1792
+
1793
+ Example:
1794
+ >>> if db.exists("users", "email = ?", ("alice@example.com",)):
1795
+ ... print("User exists")
1796
+ """
1797
+ safe_table_name = self._sanitize_identifier(table_name)
1798
+ sql = f"SELECT EXISTS(SELECT 1 FROM {safe_table_name} WHERE {where})" # nosec
1799
+ cursor = self.execute(sql, parameters)
1800
+ return bool(cursor.fetchone()[0])
1801
+
1802
+ # ==================== Query Extensions ====================
1803
+
1804
+ def query_with_pagination(
1805
+ self,
1806
+ table_name: str = None,
1807
+ columns: list[str] = None,
1808
+ where: str = None,
1809
+ parameters: tuple = None,
1810
+ order_by: str = None,
1811
+ limit: int = None,
1812
+ offset: int = None,
1813
+ group_by: str = None,
1814
+ strict_sql_validation: bool = None,
1815
+ allowed_sql_functions: list[str] = None,
1816
+ forbidden_sql_functions: list[str] = None,
1817
+ override_allowed: bool = False,
1818
+ ) -> list[dict]:
1819
+ """
1820
+ 拡張されたクエリ(offset、group_by対応)
1821
+
1822
+ Args:
1823
+ table_name: テーブル名
1824
+ columns: 取得するカラム
1825
+ where: WHERE句
1826
+ parameters: パラメータ
1827
+ order_by: ORDER BY句
1828
+ limit: LIMIT句
1829
+ offset: OFFSET句(ページネーション用)
1830
+ group_by: GROUP BY句
1831
+ strict_sql_validation: Trueの場合、未許可の関数等を含むクエリを拒否
1832
+ allowed_sql_functions: このクエリで一時的に許可するSQL関数のリスト
1833
+ forbidden_sql_functions: このクエリで一時的に禁止するSQL関数のリスト
1834
+ override_allowed: Trueの場合、インスタンス許可設定を無視
1835
+
1836
+ Returns:
1837
+ 結果のリスト
1838
+
1839
+ Example:
1840
+ >>> # ページネーション
1841
+ >>> page2 = db.query_with_pagination("users",
1842
+ ... limit=10, offset=10, order_by="id ASC")
1843
+
1844
+ >>> # グループ集計
1845
+ >>> stats = db.query_with_pagination("orders",
1846
+ ... columns=["user_id", "COUNT(*) as order_count"],
1847
+ ... group_by="user_id"
1848
+ ... )
1849
+ """
1850
+ if table_name is None:
1851
+ table_name = self._table
1852
+
1853
+ safe_table_name = self._sanitize_identifier(table_name)
1854
+
1855
+ # バリデーション
1856
+ self._validate_expression(
1857
+ where,
1858
+ strict_sql_validation,
1859
+ allowed_sql_functions,
1860
+ forbidden_sql_functions,
1861
+ override_allowed,
1862
+ context="where",
1863
+ )
1864
+ self._validate_expression(
1865
+ order_by,
1866
+ strict_sql_validation,
1867
+ allowed_sql_functions,
1868
+ forbidden_sql_functions,
1869
+ override_allowed,
1870
+ context="order_by",
1871
+ )
1872
+ self._validate_expression(
1873
+ group_by,
1874
+ strict_sql_validation,
1875
+ allowed_sql_functions,
1876
+ forbidden_sql_functions,
1877
+ override_allowed,
1878
+ context="group_by",
1879
+ )
1880
+ if columns:
1881
+ for col in columns:
1882
+ self._validate_expression(
1883
+ col,
1884
+ strict_sql_validation,
1885
+ allowed_sql_functions,
1886
+ forbidden_sql_functions,
1887
+ override_allowed,
1888
+ context="column",
1889
+ )
1890
+
1891
+ # Validate limit and offset are non-negative integers if provided
1892
+ if limit is not None:
1893
+ if not isinstance(limit, int):
1894
+ raise ValueError(f"limit must be an integer, got {type(limit).__name__}")
1895
+ if limit < 0:
1896
+ raise ValueError("limit must be non-negative")
1897
+
1898
+ if offset is not None:
1899
+ if not isinstance(offset, int):
1900
+ raise ValueError(f"offset must be an integer, got {type(offset).__name__}")
1901
+ if offset < 0:
1902
+ raise ValueError("offset must be non-negative")
1903
+
1904
+ # カラム指定
1905
+ if columns is None:
1906
+ columns_sql = "*"
1907
+ else:
1908
+ # 識別子(カラム名のみ)の場合はサニタイズ、式の場合はそのまま(バリデーション済み)
1909
+ safe_cols = []
1910
+ for col in columns:
1911
+ if IDENTIFIER_PATTERN.match(col):
1912
+ safe_cols.append(self._sanitize_identifier(col))
1913
+ else:
1914
+ safe_cols.append(col)
1915
+ columns_sql = ", ".join(safe_cols)
1916
+
1917
+ # SQL構築
1918
+ sql = f"SELECT {columns_sql} FROM {safe_table_name}" # nosec
1919
+
1920
+ if where:
1921
+ sql += f" WHERE {where}"
1922
+
1923
+ if group_by:
1924
+ sql += f" GROUP BY {group_by}"
1925
+
1926
+ if order_by:
1927
+ sql += f" ORDER BY {order_by}"
1928
+
1929
+ if limit is not None:
1930
+ sql += f" LIMIT {limit}"
1931
+
1932
+ if offset is not None:
1933
+ sql += f" OFFSET {offset}"
1934
+
1935
+ # 実行
1936
+ cursor = self.execute(sql, parameters)
1937
+
1938
+ # カラム名取得
1939
+ if columns is None:
1940
+ pragma_cursor = self.execute(f"PRAGMA table_info({safe_table_name})")
1941
+ col_names = [row[1] for row in pragma_cursor]
1942
+ else:
1943
+ # カラム名からAS句を考慮(case-insensitive)
1944
+ col_names = []
1945
+ for col in columns:
1946
+ parts = re.split(r"\s+as\s+", col, flags=re.IGNORECASE)
1947
+ if len(parts) > 1:
1948
+ col_names.append(parts[-1].strip().strip('"').strip("'"))
1949
+ else:
1950
+ col_names.append(col.strip().strip('"').strip("'"))
1951
+
1952
+ # 結果をdictのリストに変換
1953
+ results = []
1954
+ for row in cursor:
1955
+ results.append(dict(zip(col_names, row)))
1956
+
1957
+ return results
1958
+
1959
+ # ==================== Utility Functions ====================
1960
+
1961
+ def vacuum(self) -> None:
1962
+ """
1963
+ データベースを最適化(VACUUM実行)
1964
+
1965
+ 削除されたレコードの領域を回収し、データベースファイルを最適化。
1966
+
1967
+ Example:
1968
+ >>> db.vacuum()
1969
+ """
1970
+ self.execute("VACUUM")
1971
+
1972
+ def get_db_size(self) -> int:
1973
+ """
1974
+ データベースファイルのサイズを取得(バイト単位)
1975
+
1976
+ Returns:
1977
+ データベースファイルのサイズ
1978
+
1979
+ Example:
1980
+ >>> size = db.get_db_size()
1981
+ >>> print(f"DB size: {size / 1024 / 1024:.2f} MB")
1982
+ """
1983
+ import os
1984
+
1985
+ return os.path.getsize(self._db_path)
1986
+
1987
+ def export_table_to_dict(self, table_name: str) -> list[dict]:
1988
+ """
1989
+ テーブル全体をdictのリストとして取得
1990
+
1991
+ Args:
1992
+ table_name: テーブル名
1993
+
1994
+ Returns:
1995
+ 全レコードのリスト
1996
+
1997
+ Example:
1998
+ >>> all_users = db.export_table_to_dict("users")
1999
+ """
2000
+ return self.query_with_pagination(table_name=table_name)
2001
+
2002
+ def import_from_dict_list(self, table_name: str, data_list: list[dict]) -> int:
2003
+ """
2004
+ dictのリストからテーブルに一括挿入
2005
+
2006
+ Args:
2007
+ table_name: テーブル名
2008
+ data_list: 挿入するデータのリスト
2009
+
2010
+ Returns:
2011
+ 挿入された行数
2012
+
2013
+ Example:
2014
+ >>> users = [
2015
+ ... {"name": "Alice", "age": 25},
2016
+ ... {"name": "Bob", "age": 30}
2017
+ ... ]
2018
+ >>> count = db.import_from_dict_list("users", users)
2019
+ """
2020
+ if not data_list:
2021
+ return 0
2022
+
2023
+ safe_table_name = self._sanitize_identifier(table_name)
2024
+
2025
+ # 最初のdictからカラム名を取得
2026
+ columns = list(data_list[0].keys())
2027
+ safe_columns = [self._sanitize_identifier(col) for col in columns]
2028
+ placeholders = ", ".join(["?"] * len(columns))
2029
+ columns_sql = ", ".join(safe_columns)
2030
+ sql = f"INSERT INTO {safe_table_name} ({columns_sql}) VALUES ({placeholders})" # nosec
2031
+
2032
+ # 各dictから値を抽出
2033
+ parameters_list = []
2034
+ for data in data_list:
2035
+ values = [data.get(col) for col in columns]
2036
+ parameters_list.append(tuple(values))
2037
+
2038
+ self.execute_many(sql, parameters_list)
2039
+ return len(data_list)
2040
+
2041
+ def get_last_insert_rowid(self) -> int:
2042
+ """
2043
+ 最後に挿入されたROWIDを取得
2044
+
2045
+ Returns:
2046
+ 最後に挿入されたROWID
2047
+
2048
+ Example:
2049
+ >>> db.sql_insert("users", {"name": "Alice"})
2050
+ >>> rowid = db.get_last_insert_rowid()
2051
+ """
2052
+ return self._connection.last_insert_rowid()
2053
+
2054
+ def pragma(self, pragma_name: str, value: Any = None) -> Any:
2055
+ """
2056
+ PRAGMA設定の取得/設定
2057
+
2058
+ Args:
2059
+ pragma_name: PRAGMA名
2060
+ value: 設定値(Noneの場合は取得のみ)
2061
+
2062
+ Returns:
2063
+ valueがNoneの場合は現在の値、そうでない場合はNone
2064
+
2065
+ Example:
2066
+ >>> # 取得
2067
+ >>> mode = db.pragma("journal_mode")
2068
+
2069
+ >>> # 設定
2070
+ >>> db.pragma("foreign_keys", 1)
2071
+ """
2072
+ # Whitelist of allowed PRAGMA commands for security
2073
+ ALLOWED_PRAGMAS = {
2074
+ "foreign_keys",
2075
+ "journal_mode",
2076
+ "synchronous",
2077
+ "cache_size",
2078
+ "temp_store",
2079
+ "locking_mode",
2080
+ "auto_vacuum",
2081
+ "page_size",
2082
+ "encoding",
2083
+ "user_version",
2084
+ "schema_version",
2085
+ "wal_autocheckpoint",
2086
+ "busy_timeout",
2087
+ "query_only",
2088
+ "recursive_triggers",
2089
+ "secure_delete",
2090
+ "table_info",
2091
+ "index_list",
2092
+ "index_info",
2093
+ "database_list",
2094
+ }
2095
+
2096
+ if pragma_name not in ALLOWED_PRAGMAS:
2097
+ raise ValueError(f"PRAGMA '{pragma_name}' is not allowed. Allowed: {', '.join(sorted(ALLOWED_PRAGMAS))}")
2098
+
2099
+ if value is None:
2100
+ cursor = self.execute(f"PRAGMA {pragma_name}")
2101
+ result = cursor.fetchone()
2102
+ return result[0] if result else None
2103
+ else:
2104
+ # Validate value is safe (int, float, or simple string)
2105
+ if not isinstance(value, (int, float, str)):
2106
+ raise ValueError(f"PRAGMA value must be int, float, or str, got {type(value).__name__}")
2107
+
2108
+ # For string values, validate to prevent SQL injection
2109
+ if isinstance(value, str):
2110
+ # Only allow alphanumeric, underscore, dash, and dots for string values
2111
+ if not re.match(r"^[\w\-\.]+$", value):
2112
+ raise ValueError(
2113
+ "PRAGMA string value must contain only alphanumeric, underscore, dash, or dot characters"
2114
+ )
2115
+ value_str = f"'{value}'"
2116
+ else:
2117
+ value_str = str(value)
2118
+
2119
+ self.execute(f"PRAGMA {pragma_name} = {value_str}")
2120
+ return None
2121
+
2122
+ # ==================== Transaction Control ====================
2123
+
2124
+ def begin_transaction(self) -> None:
2125
+ """
2126
+ トランザクションを開始
2127
+
2128
+ Note:
2129
+ SQLiteはネストされたトランザクションをサポートしていません。
2130
+ 既にトランザクション中の場合、NanaSQLiteTransactionErrorが発生します。
2131
+
2132
+ Raises:
2133
+ NanaSQLiteTransactionError: 既にトランザクション中の場合
2134
+ NanaSQLiteConnectionError: 接続が閉じられている場合
2135
+ NanaSQLiteDatabaseError: トランザクション開始に失敗した場合
2136
+
2137
+ Example:
2138
+ >>> db.begin_transaction()
2139
+ >>> try:
2140
+ ... db.sql_insert("users", {"name": "Alice"})
2141
+ ... db.sql_insert("users", {"name": "Bob"})
2142
+ ... db.commit()
2143
+ ... except:
2144
+ ... db.rollback()
2145
+ """
2146
+ self._check_connection()
2147
+
2148
+ if self._in_transaction:
2149
+ raise NanaSQLiteTransactionError(
2150
+ "Transaction already in progress. "
2151
+ "SQLite does not support nested transactions. "
2152
+ "Please commit or rollback the current transaction first."
2153
+ )
2154
+
2155
+ try:
2156
+ self.execute("BEGIN IMMEDIATE")
2157
+ self._in_transaction = True
2158
+ self._transaction_depth = 1
2159
+ except Exception as e:
2160
+ raise NanaSQLiteDatabaseError(
2161
+ f"Failed to begin transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
2162
+ ) from e
2163
+
2164
+ def commit(self) -> None:
2165
+ """
2166
+ トランザクションをコミット
2167
+
2168
+ Raises:
2169
+ NanaSQLiteTransactionError: トランザクション外でコミットを試みた場合
2170
+ NanaSQLiteConnectionError: 接続が閉じられている場合
2171
+ NanaSQLiteDatabaseError: コミットに失敗した場合
2172
+ """
2173
+ self._check_connection()
2174
+
2175
+ if not self._in_transaction:
2176
+ raise NanaSQLiteTransactionError(
2177
+ "No transaction in progress. Call begin_transaction() first or use the transaction() context manager."
2178
+ )
2179
+
2180
+ try:
2181
+ self.execute("COMMIT")
2182
+ self._in_transaction = False
2183
+ self._transaction_depth = 0
2184
+ except Exception as e:
2185
+ # コミット失敗時は状態を維持(ロールバックが必要)
2186
+ raise NanaSQLiteDatabaseError(
2187
+ f"Failed to commit transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
2188
+ ) from e
2189
+
2190
+ def rollback(self) -> None:
2191
+ """
2192
+ トランザクションをロールバック
2193
+
2194
+ Raises:
2195
+ NanaSQLiteTransactionError: トランザクション外でロールバックを試みた場合
2196
+ NanaSQLiteConnectionError: 接続が閉じられている場合
2197
+ NanaSQLiteDatabaseError: ロールバックに失敗した場合
2198
+ """
2199
+ self._check_connection()
2200
+
2201
+ if not self._in_transaction:
2202
+ raise NanaSQLiteTransactionError(
2203
+ "No transaction in progress. Call begin_transaction() first or use the transaction() context manager."
2204
+ )
2205
+
2206
+ try:
2207
+ self.execute("ROLLBACK")
2208
+ self._in_transaction = False
2209
+ self._transaction_depth = 0
2210
+ except Exception as e:
2211
+ # ロールバック失敗は深刻なので状態をリセット
2212
+ self._in_transaction = False
2213
+ self._transaction_depth = 0
2214
+ raise NanaSQLiteDatabaseError(
2215
+ f"Failed to rollback transaction: {e}", original_error=e if isinstance(e, apsw.Error) else None
2216
+ ) from e
2217
+
2218
+ def in_transaction(self) -> bool:
2219
+ """
2220
+ 現在トランザクション中かどうかを返す
2221
+
2222
+ Returns:
2223
+ bool: トランザクション中の場合True
2224
+
2225
+ Example:
2226
+ >>> db.begin_transaction()
2227
+ >>> print(db.in_transaction()) # True
2228
+ >>> db.commit()
2229
+ >>> print(db.in_transaction()) # False
2230
+ """
2231
+ return self._in_transaction
2232
+
2233
+ def transaction(self):
2234
+ """
2235
+ トランザクションのコンテキストマネージャ
2236
+
2237
+ コンテキストマネージャ内で例外が発生しない場合は自動的にコミット、
2238
+ 例外が発生した場合は自動的にロールバックします。
2239
+
2240
+ Raises:
2241
+ NanaSQLiteTransactionError: 既にトランザクション中の場合
2242
+
2243
+ Example:
2244
+ >>> with db.transaction():
2245
+ ... db.sql_insert("users", {"name": "Alice"})
2246
+ ... db.sql_insert("users", {"name": "Bob"})
2247
+ ... # 自動的にコミット、例外時はロールバック
2248
+ """
2249
+ return _TransactionContext(self)
2250
+
2251
+ def table(
2252
+ self,
2253
+ table_name: str,
2254
+ cache_strategy: CacheType | Literal["unbounded", "lru"] | None = None,
2255
+ cache_size: int | None = None,
2256
+ ):
2257
+ """
2258
+ サブテーブル用のNanaSQLiteインスタンスを取得
2259
+
2260
+ 新しいインスタンスを作成しますが、SQLite接続とロックは共有します。
2261
+ これにより、複数のテーブルインスタンスが同じ接続を使用して
2262
+ スレッドセーフに動作します。
2263
+
2264
+ Args:
2265
+ table_name: テーブル名
2266
+ cache_strategy: このテーブル用のキャッシュ戦略 (デフォルト: 親と同じ)
2267
+ cache_size: このテーブル用のキャッシュサイズ (デフォルト: 親と同じ)
2268
+
2269
+ ⚠️ 重要な注意事項:
2270
+ - 同じテーブルに対して複数のインスタンスを作成しないでください
2271
+ 各インスタンスは独立したキャッシュを持つため、キャッシュ不整合が発生します
2272
+ - 推奨: テーブルインスタンスを変数に保存して再利用してください
2273
+
2274
+ 非推奨:
2275
+ sub1 = db.table("users")
2276
+ sub2 = db.table("users") # キャッシュ不整合の原因
2277
+
2278
+ 推奨:
2279
+ users_db = db.table("users")
2280
+ # users_dbを使い回す
2281
+
2282
+ :param table_name: テーブル名
2283
+ :return NanaSQLite: 新しいテーブルインスタンス
2284
+
2285
+ Raises:
2286
+ NanaSQLiteConnectionError: 接続が閉じられている場合
2287
+
2288
+ Example:
2289
+ >>> with NanaSQLite("app.db", table="main") as main_db:
2290
+ ... users_db = main_db.table("users")
2291
+ ... products_db = main_db.table("products")
2292
+ ... users_db["user1"] = {"name": "Alice"}
2293
+ ... products_db["prod1"] = {"name": "Laptop"}
2294
+ """
2295
+ self._check_connection()
2296
+
2297
+ # 指定がなければデフォルト(UNBOUNDED)
2298
+ strat = cache_strategy if cache_strategy is not None else CacheType.UNBOUNDED
2299
+ size = cache_size
2300
+
2301
+ child = NanaSQLite(
2302
+ self._db_path,
2303
+ table=table_name,
2304
+ cache_strategy=strat,
2305
+ cache_size=size,
2306
+ _shared_connection=self._connection,
2307
+ _shared_lock=self._lock,
2308
+ )
2309
+
2310
+ # If the parent is the connection owner, the child is not.
2311
+ # This ensures only one instance (the owner) attempts to close the connection.
2312
+ if self._is_connection_owner:
2313
+ child._is_connection_owner = False
2314
+
2315
+ # 子インスタンスを追跡 (WeakSetに直接オブジェクトを追加すると、WeakSetが弱参照を保持する)
2316
+ self._child_instances.add(child)
2317
+
2318
+ return child
2319
+
2320
+
2321
+ class _TransactionContext:
2322
+ """トランザクションのコンテキストマネージャ"""
2323
+
2324
+ def __init__(self, db: NanaSQLite):
2325
+ self.db = db
2326
+
2327
+ def __enter__(self):
2328
+ self.db.begin_transaction()
2329
+ return self.db
2330
+
2331
+ def __exit__(self, exc_type, exc_val, exc_tb):
2332
+ if exc_type is None:
2333
+ self.db.commit()
2334
+ else:
2335
+ self.db.rollback()
2336
+ return False