compair-core 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compair_core/__init__.py +8 -0
- compair_core/api.py +3598 -0
- compair_core/compair/__init__.py +57 -0
- compair_core/compair/celery_app.py +31 -0
- compair_core/compair/default_groups.py +14 -0
- compair_core/compair/embeddings.py +141 -0
- compair_core/compair/feedback.py +368 -0
- compair_core/compair/logger.py +29 -0
- compair_core/compair/main.py +276 -0
- compair_core/compair/models.py +453 -0
- compair_core/compair/schema.py +146 -0
- compair_core/compair/tasks.py +106 -0
- compair_core/compair/utils.py +42 -0
- compair_core/compair_email/__init__.py +0 -0
- compair_core/compair_email/email.py +6 -0
- compair_core/compair_email/email_core.py +15 -0
- compair_core/compair_email/templates.py +6 -0
- compair_core/compair_email/templates_core.py +32 -0
- compair_core/db.py +64 -0
- compair_core/server/__init__.py +0 -0
- compair_core/server/app.py +97 -0
- compair_core/server/deps.py +77 -0
- compair_core/server/local_model/__init__.py +1 -0
- compair_core/server/local_model/app.py +87 -0
- compair_core/server/local_model/ocr.py +107 -0
- compair_core/server/providers/__init__.py +0 -0
- compair_core/server/providers/console_mailer.py +9 -0
- compair_core/server/providers/contracts.py +66 -0
- compair_core/server/providers/http_ocr.py +60 -0
- compair_core/server/providers/local_storage.py +28 -0
- compair_core/server/providers/noop_analytics.py +7 -0
- compair_core/server/providers/noop_billing.py +30 -0
- compair_core/server/providers/noop_ocr.py +10 -0
- compair_core/server/routers/__init__.py +0 -0
- compair_core/server/routers/capabilities.py +46 -0
- compair_core/server/settings.py +66 -0
- compair_core-0.4.12.dist-info/METADATA +136 -0
- compair_core-0.4.12.dist-info/RECORD +41 -0
- compair_core-0.4.12.dist-info/WHEEL +5 -0
- compair_core-0.4.12.dist-info/licenses/LICENSE +674 -0
- compair_core-0.4.12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
from datetime import datetime, timedelta, timezone
|
|
5
|
+
|
|
6
|
+
from sqlalchemy.orm import Session
|
|
7
|
+
|
|
8
|
+
from .models import Activity
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def chunk_text(text: str) -> list[str]:
    """Split *text* into paragraph chunks.

    Paragraphs are separated by a literal blank line (``"\\n\\n"``). Each piece
    is stripped of surrounding whitespace and pieces that end up empty are
    dropped, so runs of blank lines never yield empty chunks.
    """
    stripped_pieces = (piece.strip() for piece in text.split("\n\n"))
    return [piece for piece in stripped_pieces if piece]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def generate_verification_token() -> tuple[str, datetime]:
    """Create a fresh verification token together with its expiry time.

    Returns:
        ``(token, expires_at)`` where ``token`` is a URL-safe random string
        built from 32 bytes of entropy (:func:`secrets.token_urlsafe`) and
        ``expires_at`` is a timezone-aware UTC datetime 24 hours from now.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    return secrets.token_urlsafe(32), expires_at
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def log_activity(
    session: Session,
    user_id: str,
    group_id: str,
    action: str,
    object_id: str,
    object_name: str,
    object_type: str,
) -> None:
    """Persist one ``Activity`` row describing a user action, then commit.

    The row is stamped with the current UTC time. Because this calls
    ``session.commit()``, any other pending changes on *session* are committed
    along with it.
    """
    entry = Activity(
        user_id=user_id,
        group_id=group_id,
        object_id=object_id,
        object_name=object_name,
        object_type=object_type,
        action=action,
        timestamp=datetime.now(timezone.utc),
    )
    session.add(entry)
    session.commit()
|
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from redmail import EmailSender
|
|
4
|
+
|
|
5
|
+
# SMTP settings come from the environment. A missing variable becomes "" --
# wrapping os.environ.get() in an f-string would turn it into the literal
# (and truthy) string "None", which looks like a real configured value.
EMAIL_HOST = os.environ.get("EMAIL_HOST", "")
EMAIL_USER = os.environ.get("EMAIL_USER", "")
EMAIL_PW = os.environ.get("EMAIL_PW", "")

# Module-level sender: SMTP on port 587 with STARTTLS.
emailer = EmailSender(
    host=EMAIL_HOST,
    port=587,
    username=EMAIL_USER,
    password=EMAIL_PW,
    use_starttls=True,
)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Minimal email templates for the core edition."""
|
|
2
|
+
|
|
3
|
+
# HTML bodies for transactional emails. Placeholders use ``{{name}}`` syntax
# and are filled in by whatever renders these templates (not visible here --
# presumably a Jinja-style renderer; confirm). Each template is its lines joined
# with "\n", without leading or trailing whitespace.

ACCOUNT_VERIFY_TEMPLATE = "\n".join(
    (
        "<p>Hi {{user_name}},</p>",
        "<p>Please verify your Compair account by clicking the link below:</p>",
        '<p><a href="{{verify_link}}">Verify my account</a></p>',
        "<p>Thanks!</p>",
    )
)

PASSWORD_RESET_TEMPLATE = "\n".join(
    (
        "<p>We received a request to reset your password.</p>",
        "<p>Your password reset code is: <strong>{{reset_code}}</strong></p>",
    )
)

GROUP_INVITATION_TEMPLATE = "\n".join(
    (
        "<p>{{inviter_name}} invited you to join the group {{group_name}}.</p>",
        '<p><a href="{{invitation_link}}">Accept invitation</a></p>',
    )
)

GROUP_JOIN_TEMPLATE = "<p>{{user_name}} has joined your group.</p>"

INDIVIDUAL_INVITATION_TEMPLATE = "\n".join(
    (
        "<p>{{inviter_name}} invited you to Compair.</p>",
        '<p><a href="{{referral_link}}">Join now</a></p>',
    )
)

REFERRAL_CREDIT_TEMPLATE = "\n".join(
    (
        "<p>Hi {{user_name}},</p>",
        "<p>Great news! You now have {{referral_credits}} referral credits.</p>",
    )
)
|
compair_core/db.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import Engine, create_engine
|
|
7
|
+
from sqlalchemy.orm import sessionmaker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _build_engine() -> Engine:
    """Create the SQLAlchemy engine using the same precedence as the core package.

    Resolution order:

    1. An explicit URL from ``COMPAIR_DATABASE_URL``, ``COMPAIR_DB_URL`` or
       ``DATABASE_URL`` (first non-empty wins). SQLite URLs are opened with
       ``check_same_thread=False``.
    2. The legacy Postgres variables ``DB``, ``DB_USER``, ``DB_PASSW`` and
       ``DB_URL`` (used as the host part of the URL), when all four are set.
    3. A local SQLite file; see ``_sqlite_path``.
    """
    explicit_url = _explicit_database_url()
    if explicit_url:
        if explicit_url.startswith("sqlite:"):
            # sqlite3 otherwise refuses use of a connection from a thread other
            # than the one that created it.
            return create_engine(explicit_url, connect_args={"check_same_thread": False})
        return create_engine(explicit_url)

    legacy_url = _legacy_postgres_url()
    if legacy_url is not None:
        return create_engine(legacy_url, pool_size=10, max_overflow=0)

    return create_engine(
        f"sqlite:///{_sqlite_path()}",
        connect_args={"check_same_thread": False},
    )


def _explicit_database_url() -> str | None:
    """Return the first non-empty explicit database URL from the environment."""
    return (
        os.getenv("COMPAIR_DATABASE_URL")
        or os.getenv("COMPAIR_DB_URL")
        or os.getenv("DATABASE_URL")
    )


def _legacy_postgres_url() -> str | None:
    """Build a Postgres URL from the legacy ``DB*`` variables, or None if any is unset."""
    from urllib.parse import quote

    db = os.getenv("DB")
    db_user = os.getenv("DB_USER")
    db_passw = os.getenv("DB_PASSW")
    db_host = os.getenv("DB_URL")
    if not all([db, db_user, db_passw, db_host]):
        return None

    # Percent-encode the credentials so characters such as "@", ":" or "/" in
    # a password cannot corrupt the URL; SQLAlchemy unquotes them when parsing.
    # quote(safe="") is used rather than quote_plus so spaces become %20, which
    # plain unquoting restores correctly.
    user = quote(db_user, safe="")
    password = quote(db_passw, safe="")
    return f"postgresql+psycopg2://{user}:{password}@{db_host}/{db}"


def _sqlite_path() -> Path:
    """Return the local SQLite database file path, creating its directory.

    The directory comes from ``COMPAIR_DB_DIR`` or ``COMPAIR_SQLITE_DIR`` (with
    ``~`` expanded), defaulting to ``~/.compair-core/data``. The file name comes
    from ``COMPAIR_DB_NAME`` or ``COMPAIR_SQLITE_NAME``, defaulting to
    ``compair.db``. If the directory cannot be created, or no home directory can
    be resolved, ``compair_data`` under the current working directory is used.
    """
    configured_dir = os.getenv("COMPAIR_DB_DIR") or os.getenv("COMPAIR_SQLITE_DIR")
    db_name = os.getenv("COMPAIR_DB_NAME") or os.getenv("COMPAIR_SQLITE_NAME") or "compair.db"

    try:
        if configured_dir:
            db_dir = Path(configured_dir).expanduser()
        else:
            db_dir = Path.home() / ".compair-core" / "data"
        db_dir.mkdir(parents=True, exist_ok=True)
    except (OSError, RuntimeError):
        # OSError: the directory is not creatable/writable.
        # RuntimeError: Path.home()/expanduser() could not resolve a home
        # directory (e.g. a container user without one).
        db_dir = Path.cwd() / "compair_data"
        db_dir.mkdir(parents=True, exist_ok=True)

    return db_dir / db_name
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Process-wide engine, resolved from the environment at import time.
engine = _build_engine()

# Session factory bound to the engine (equivalent to ``sessionmaker(engine)``).
# ``Session`` is kept as a backwards-compatible alias of ``SessionLocal``.
SessionLocal = sessionmaker(bind=engine)
Session = SessionLocal

__all__ = ["engine", "SessionLocal", "Session"]
|
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""FastAPI app factory supporting Core and Cloud editions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from fastapi import FastAPI
|
|
5
|
+
|
|
6
|
+
from .deps import (
|
|
7
|
+
get_analytics,
|
|
8
|
+
get_billing,
|
|
9
|
+
get_mailer,
|
|
10
|
+
get_ocr,
|
|
11
|
+
get_settings_dependency,
|
|
12
|
+
get_storage,
|
|
13
|
+
)
|
|
14
|
+
from .providers.local_storage import LocalStorage
|
|
15
|
+
from .routers.capabilities import router as capabilities_router
|
|
16
|
+
from .settings import Settings, get_settings
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _normalize_edition(value: str) -> str:
|
|
20
|
+
return (value or "core").lower()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def create_app(settings: Settings | None = None) -> FastAPI:
    """Build the Compair API application for the configured edition.

    Args:
        settings: Settings to build from; ``get_settings()`` is used when omitted.

    Returns:
        A FastAPI app whose routers and dependency overrides match
        ``settings.edition``: ``"cloud"``, or anything else treated as core.

    Raises:
        RuntimeError: For the cloud edition when the private ``compair_cloud``
            package cannot be imported.
    """
    cfg = settings or get_settings()
    is_cloud = _normalize_edition(cfg.edition) == "cloud"

    application = FastAPI(title="Compair API", version=cfg.version)
    _register_routers(application, cfg, is_cloud=is_cloud)

    # Handlers resolve settings through this dependency; pin it to the exact
    # object this app was built from.
    application.dependency_overrides[get_settings_dependency] = lambda: cfg

    if is_cloud:
        _install_cloud_providers(application, cfg)
    else:
        _install_core_providers(application, cfg)
    return application


def _register_routers(application: FastAPI, cfg: Settings, *, is_cloud: bool) -> None:
    """Mount the API routers for the edition.

    Cloud mounts the full legacy router. Core mounts the legacy router only
    when ``include_legacy_routes`` is set (the core router otherwise), followed
    by the capabilities router.
    """
    # Deferred import -- presumably to avoid an import cycle with ``..api``; confirm.
    from ..api import core_router, router as legacy_router

    if is_cloud:
        application.include_router(legacy_router)
        return
    application.include_router(legacy_router if cfg.include_legacy_routes else core_router)
    application.include_router(capabilities_router)


def _install_cloud_providers(application: FastAPI, cfg: Settings) -> None:
    """Override storage/billing/OCR/mailer (and, if configured, analytics) with cloud providers."""
    try:
        from compair_cloud.analytics.ga4 import GA4Analytics
        from compair_cloud.billing.stripe_provider import StripeBilling
        from compair_cloud.mailer.transactional import TransactionalMailer
        from compair_cloud.ocr.claude_ocr import ClaudeOCR
        from compair_cloud.storage.r2_storage import R2Storage
    except ImportError as exc:  # pragma: no cover - only triggered in misconfigured builds
        raise RuntimeError(
            "Cloud edition requires the private 'compair_cloud' package to be installed."
        ) from exc

    storage = R2Storage(
        bucket=cfg.r2_bucket,
        cdn_base=cfg.r2_cdn_base,
        access_key=cfg.r2_access_key,
        secret_key=cfg.r2_secret_key,
        endpoint_url=cfg.r2_endpoint_url,
    )
    billing = StripeBilling(
        stripe_key=cfg.stripe_key,
        endpoint_secret=cfg.stripe_endpoint_secret,
    )
    ocr = ClaudeOCR()
    mailer = TransactionalMailer()

    # Analytics is optional: without GA4 credentials the default provider stays.
    analytics = None
    if cfg.ga4_measurement_id and cfg.ga4_api_secret:
        analytics = GA4Analytics(
            measurement_id=cfg.ga4_measurement_id,
            api_secret=cfg.ga4_api_secret,
        )

    overrides = application.dependency_overrides
    overrides[get_storage] = lambda sp=storage: sp
    overrides[get_billing] = lambda bp=billing: bp
    overrides[get_ocr] = lambda op=ocr: op
    if analytics is not None:
        overrides[get_analytics] = lambda ap=analytics: ap
    overrides[get_mailer] = lambda mp=mailer: mp

    # object.__setattr__ sidesteps any immutability guard on Settings
    # (presumably frozen -- confirm). This mutates the settings object itself,
    # which may be the shared instance returned by get_settings().
    object.__setattr__(cfg, "ocr_enabled", True)


def _install_core_providers(application: FastAPI, cfg: Settings) -> None:
    """Override storage with a LocalStorage built from the app's settings."""
    local = LocalStorage(
        base_dir=cfg.local_upload_dir,
        base_url=cfg.local_upload_base_url,
    )
    application.dependency_overrides[get_storage] = lambda sp=local: sp
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Module-level application for ASGI servers, e.g.
# ``uvicorn compair_core.server.app:app``. This module uses relative imports
# (``.deps``, ``..api``), so it must be loaded as part of the ``compair_core``
# package; ``uvicorn server.app:app`` will not resolve them.
app = create_app()
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Dependency entry points for features that differ by edition."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
|
|
6
|
+
from fastapi import Depends
|
|
7
|
+
|
|
8
|
+
from .providers.console_mailer import ConsoleMailer
|
|
9
|
+
from .providers.contracts import Analytics, BillingProvider, Mailer, OCRProvider, StorageProvider
|
|
10
|
+
from .providers.http_ocr import HTTPOCR
|
|
11
|
+
from .providers.local_storage import LocalStorage
|
|
12
|
+
from .providers.noop_analytics import NoopAnalytics
|
|
13
|
+
from .providers.noop_billing import NoopBilling
|
|
14
|
+
from .providers.noop_ocr import NoopOCR
|
|
15
|
+
from .settings import Settings, get_settings
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache()
def _local_storage_factory(base_dir: str, base_url: str) -> LocalStorage:
    """Return a LocalStorage for this (directory, URL) pair, reused across calls."""
    storage = LocalStorage(base_dir=base_dir, base_url=base_url)
    return storage
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@lru_cache()
def _noop_billing() -> NoopBilling:
    """Return the shared NoopBilling instance, created on first use."""
    instance = NoopBilling()
    return instance
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@lru_cache()
def _noop_ocr() -> NoopOCR:
    """Return the shared NoopOCR instance, created on first use."""
    instance = NoopOCR()
    return instance
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@lru_cache()
def _http_ocr(endpoint: str, timeout: float) -> HTTPOCR:
    """Return a shared HTTPOCR client per distinct (endpoint, timeout) pair.

    Caching matters for more than construction cost: HTTPOCR keeps submitted
    results in memory, so ``status`` lookups must reach the same instance that
    handled ``submit``.
    """
    return HTTPOCR(timeout=timeout, endpoint=endpoint)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@lru_cache()
def _console_mailer() -> ConsoleMailer:
    """Return the shared ConsoleMailer instance, created on first use."""
    instance = ConsoleMailer()
    return instance
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@lru_cache()
def _noop_analytics() -> NoopAnalytics:
    """Return the shared NoopAnalytics instance, created on first use."""
    instance = NoopAnalytics()
    return instance
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_settings_dependency() -> Settings:
    """FastAPI dependency returning the active settings.

    Kept as a separate function (rather than depending on ``get_settings``
    directly) so an app can replace it through ``dependency_overrides``.
    """
    current = get_settings()
    return current
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_storage(
    settings: Settings = Depends(get_settings_dependency),
) -> StorageProvider:
    """Default storage dependency: local-disk storage configured from *settings*."""
    upload_dir = settings.local_upload_dir
    upload_url = settings.local_upload_base_url
    return _local_storage_factory(upload_dir, upload_url)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_billing() -> BillingProvider:
    """Default billing dependency: the shared no-op billing provider."""
    provider = _noop_billing()
    return provider
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_ocr(
    settings: Settings = Depends(get_settings_dependency),
) -> OCRProvider:
    """Default OCR dependency.

    Uses the HTTP OCR client only when OCR is enabled *and* an endpoint is
    configured; otherwise falls back to the no-op provider.
    """
    if not settings.ocr_enabled or not settings.ocr_endpoint:
        return _noop_ocr()
    return _http_ocr(settings.ocr_endpoint, settings.ocr_request_timeout)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_mailer() -> Mailer:
    """Default mailer dependency: the shared console mailer (prints, never sends)."""
    mailer = _console_mailer()
    return mailer
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_analytics() -> Analytics:
    """Default analytics dependency: the shared no-op analytics provider."""
    tracker = _noop_analytics()
    return tracker
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Local model endpoints for the Core edition."""
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Minimal FastAPI application serving local embedding and generation endpoints."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import hashlib
|
|
5
|
+
import os
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
from fastapi import FastAPI
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
# ASGI application exposing the local /embed and /generate stub endpoints.
app = FastAPI(
    title="Compair Local Model",
    version="0.1.0",
)
|
|
12
|
+
|
|
13
|
+
_DEFAULT_DIM = 384

# Embedding size: the first non-empty of these variables wins.
_DIM_ENV = (
    os.getenv("COMPAIR_EMBEDDING_DIM")
    or os.getenv("COMPAIR_EMBEDDING_DIMENSION")
    or os.getenv("COMPAIR_LOCAL_EMBED_DIM")
    or str(_DEFAULT_DIM)
)
try:
    EMBED_DIMENSION = int(_DIM_ENV)
except ValueError:  # pragma: no cover - invalid configuration
    EMBED_DIMENSION = _DEFAULT_DIM

if EMBED_DIMENSION <= 0:  # pragma: no cover - invalid configuration
    # A non-positive size would make every hash embedding an empty list, so it
    # is treated as invalid configuration just like a non-numeric value.
    EMBED_DIMENSION = _DEFAULT_DIM
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _hash_embedding(text: str, dimension: int = EMBED_DIMENSION) -> List[float]:
    """Return a deterministic pseudo-embedding of *text* derived from SHA-256.

    The first digest hashes the UTF-8 bytes of *text* (an empty/falsy value is
    treated as a single space); each following digest hashes the previous one.
    Digest bytes are mapped from [0, 255] to [-1.0, 1.0] and concatenated until
    *dimension* values exist. A non-positive *dimension* yields ``[]``.
    """
    seed = (text or " ").encode("utf-8", "ignore")
    block = hashlib.sha256(seed).digest()
    values: List[float] = []
    while len(values) < dimension:
        take = min(len(block), dimension - len(values))
        values.extend((b / 255.0) * 2 - 1 for b in block[:take])
        block = hashlib.sha256(block).digest()
    return values
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class EmbedRequest(BaseModel):
    # Request body for POST /embed.
    text: str  # text to embed; embed() hashes an empty string as a single space
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class EmbedResponse(BaseModel):
    # Response body for POST /embed.
    embedding: List[float]  # EMBED_DIMENSION values, each in [-1.0, 1.0]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class GenerateRequest(BaseModel):
    # Request body for POST /generate. Every field is optional: generate()
    # reads `document` before `prompt`, `length_instruction` before
    # `verbosity`, and only the first entry of `references`.

    # Legacy format used by the CLI shim
    system: str | None = None  # accepted but not read by generate()
    prompt: str | None = None
    verbosity: str | None = None

    # Core API payload (document + references)
    document: str | None = None
    references: List[str] | None = None
    length_instruction: str | None = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class GenerateResponse(BaseModel):
    # Response body for POST /generate.
    feedback: str  # placeholder feedback text, or "NONE" when there was no input text
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@app.post("/embed", response_model=EmbedResponse)
def embed(request: EmbedRequest) -> EmbedResponse:
    """Return the deterministic hash embedding of ``request.text``."""
    vector = _hash_embedding(request.text)
    return EmbedResponse(embedding=vector)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@app.post("/generate", response_model=GenerateResponse)
def generate(request: GenerateRequest) -> GenerateResponse:
    """Produce placeholder feedback without running a real model.

    The input is ``document`` (or, failing that, ``prompt``), stripped; if it is
    empty the feedback is ``"NONE"``. Otherwise the feedback echoes the length
    hint, the first line of the input (max 200 chars) and, when present, the
    first non-blank reference (max 160 chars).
    """
    source = (request.document or request.prompt or "").strip()
    if not source:
        return GenerateResponse(feedback="NONE")

    first_line = source.partition("\n")[0][:200]
    length_hint = request.length_instruction or request.verbosity or "brief response"

    reference_part = ""
    refs = request.references
    if refs:
        leading = (refs[0] or "").strip()
        if leading:
            reference_part = f" Reference: {leading[:160]}"

    feedback = f"[local-feedback] {length_hint}: {first_line}{reference_part}".strip()
    # The "[local-feedback]" prefix keeps this non-empty; the "NONE" guard is defensive.
    return GenerateResponse(feedback=feedback or "NONE")
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Minimal OCR endpoint leveraging pytesseract when available."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import io
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, Dict
|
|
7
|
+
|
|
8
|
+
from fastapi import FastAPI, File, HTTPException, UploadFile
|
|
9
|
+
|
|
10
|
+
# ASGI application exposing the local /ocr-file endpoint.
app = FastAPI(
    title="Compair Local OCR",
    version="0.1.0",
)
|
|
11
|
+
|
|
12
|
+
# Image OCR stack. pytesseract and Pillow share one try block, so if either
# import fails, all three names fall back together.
try:  # Optional dependency
    import pytesseract  # type: ignore
    from pytesseract import TesseractNotFoundError  # type: ignore
    from PIL import Image  # type: ignore
except ImportError:  # pragma: no cover - optional
    pytesseract = None  # type: ignore
    # Keeps the `except (TesseractNotFoundError, OSError)` clause below valid
    # even when pytesseract is missing.
    TesseractNotFoundError = OSError  # type: ignore
    Image = None  # type: ignore

try:  # Optional: text extraction for PDFs
    from pypdf import PdfReader  # type: ignore
except ImportError:  # pragma: no cover - optional
    PdfReader = None  # type: ignore

# Only the exact value "text" enables the plain-text decode fallback; any other
# value (e.g. "none") disables it.
_OCR_FALLBACK = os.getenv("COMPAIR_LOCAL_OCR_FALLBACK", "text")  # text | none

# The Python packages can be importable while the Tesseract CLI they shell out
# to is missing, so probe once at import time.
_TESSERACT_AVAILABLE = False
if pytesseract is not None and Image is not None:
    try:  # pragma: no cover - runtime probe
        pytesseract.get_tesseract_version()
        _TESSERACT_AVAILABLE = True
    except (TesseractNotFoundError, OSError):
        _TESSERACT_AVAILABLE = False
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _is_pdf(data: bytes) -> bool:
|
|
37
|
+
return data.startswith(b"%PDF")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _extract_pdf_text(data: bytes) -> str:
    """Return the embedded text layer of a PDF, or "" when unavailable.

    Pages whose extracted text is blank are skipped; the rest are joined with
    newlines and the result is stripped. Returns "" when pypdf is not installed
    or parsing raises, so the caller can try other extraction paths.
    """
    if PdfReader is None:
        return ""
    try:
        reader = PdfReader(io.BytesIO(data))
        page_texts = (page.extract_text() or "" for page in reader.pages)
        kept = [chunk for chunk in page_texts if chunk.strip()]
    except Exception:
        # Best effort: malformed or encrypted PDFs simply yield no text.
        return ""
    return "\n".join(kept).strip()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _fallback_text(data: bytes) -> str:
|
|
56
|
+
if _OCR_FALLBACK == "text":
|
|
57
|
+
try:
|
|
58
|
+
return data.decode("utf-8")
|
|
59
|
+
except UnicodeDecodeError:
|
|
60
|
+
return data.decode("latin-1", errors="ignore")
|
|
61
|
+
return ""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _extract_text(data: bytes) -> tuple[str, bool]:
    """Extract text from an uploaded payload using the best available path.

    Tries, in order: the PDF text layer (when *data* is a PDF and pypdf is
    installed), Tesseract OCR (when pytesseract, Pillow and the Tesseract CLI
    are all available), then the plain-text fallback.

    Returns:
        ``(text, attempted)`` where ``attempted`` is True if the PDF reader or
        Tesseract was run, or the fallback produced text. ``ocr_file`` treats
        ``attempted == False`` as "no extraction path available".
    """
    attempted = False

    if _is_pdf(data) and PdfReader is not None:
        attempted = True
        pdf_text = _extract_pdf_text(data)
        if pdf_text:
            return pdf_text, attempted

    ocr_ready = pytesseract is not None and Image is not None and _TESSERACT_AVAILABLE
    if ocr_ready:
        attempted = True
        try:
            ocr_text = pytesseract.image_to_string(Image.open(io.BytesIO(data)))
        except Exception:
            # Best effort: unreadable images fall through to the text fallback.
            ocr_text = ""
        if ocr_text:
            return ocr_text, attempted

    fallback = _fallback_text(data)
    return fallback, attempted or bool(fallback)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@app.post("/ocr-file")
async def ocr_file(file: UploadFile = File(...)) -> Dict[str, Any]:
    """Extract text from an uploaded file.

    Returns ``{"extracted_text": ...}``, using ``"NONE"`` when an extractor ran
    but found nothing. Raises HTTP 501 when no extraction path handled the
    upload; the detail message points at installing Tesseract when pytesseract
    is importable but the CLI probe failed.
    """
    payload = await file.read()
    text, handled = _extract_text(payload)
    if handled:
        return {"extracted_text": text or "NONE"}

    tesseract_cli_missing = pytesseract is not None and not _TESSERACT_AVAILABLE
    detail = (
        "Tesseract CLI is not installed or accessible. Install it (e.g., via your package manager) or run the container image, then retry."
        if tesseract_cli_missing
        else "OCR extraction failed. Supported formats include PDFs (text-only) and common image types."
    )
    raise HTTPException(status_code=501, detail=detail)
|
|
File without changes
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Console mailer used in Core builds to avoid delivering real email."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Iterable
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConsoleMailer:
    """Mailer that prints a summary of each message instead of sending it."""

    def send(self, subject: str, sender: str, receivers: Iterable[str], html: str) -> None:
        """Print ``[MAIL] <subject> -> [<receivers>]``; *sender* and *html* are not used."""
        recipients = list(receivers)
        print(f"[MAIL] {subject} -> {recipients}")
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Provider protocol definitions shared across editions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import BinaryIO, Iterable, Mapping, Protocol
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(slots=True)
|
|
9
|
+
class BillingSession:
|
|
10
|
+
"""Represents the result of creating a checkout session."""
|
|
11
|
+
|
|
12
|
+
id: str
|
|
13
|
+
url: str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class StorageProvider(Protocol):
    """Structural interface for file storage backends."""

    def put_file(self, key: str, fileobj: BinaryIO, content_type: str) -> str:
        """Store *fileobj* under *key*; the returned string is implementation-defined (presumably a URL or key)."""

    def get_file(self, key: str) -> tuple[BinaryIO, str]:
        """Return the stored file object for *key* plus a string (presumably its content type -- confirm)."""

    def build_url(self, key: str) -> str:
        """Return a URL for the object stored under *key*."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BillingProvider(Protocol):
    """Structural interface for billing/checkout backends."""

    def ensure_customer(self, *, user_email: str, user_id: str) -> str:
        """Return a billing customer id for the user (presumably creating one if needed)."""

    def create_checkout_session(
        self,
        *,
        customer_id: str,
        price_id: str,
        qty: int,
        success_url: str,
        cancel_url: str,
        metadata: Mapping[str, str] | None = None,
    ) -> BillingSession:
        """Start a checkout for *qty* units of *price_id* and describe the session."""

    def retrieve_session(self, session_id: str) -> BillingSession:
        """Look up a previously created checkout session."""

    def get_checkout_url(self, session_id: str) -> str:
        """Return the checkout URL for *session_id*."""

    def create_customer_portal(self, *, customer_id: str, return_url: str) -> str:
        """Create a customer portal session; the string is presumably its URL."""

    def create_coupon(self, amount: int) -> str:
        """Create a coupon for *amount*; the string is presumably the coupon id."""

    def apply_coupon(self, *, customer_id: str, coupon_id: str) -> None:
        """Attach *coupon_id* to the customer."""

    def construct_event(self, payload: bytes, signature: str | None) -> Mapping[str, object]:
        """Parse a provider event from *payload* (presumably verifying *signature* -- confirm)."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class OCRProvider(Protocol):
|
|
52
|
+
def submit(
|
|
53
|
+
self, *, user_id: str, filename: str, data: bytes, document_id: str | None
|
|
54
|
+
) -> str: ...
|
|
55
|
+
|
|
56
|
+
def status(self, task_id: str) -> Mapping[str, object]: ...
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Mailer(Protocol):
    """Structural interface for outgoing email delivery."""

    def send(self, subject: str, sender: str, receivers: Iterable[str], html: str) -> None:
        """Send an HTML message from *sender* to every address in *receivers*."""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Analytics(Protocol):
|
|
64
|
+
def track(
|
|
65
|
+
self, event_name: str, user_id: str, params: Mapping[str, object] | None = None
|
|
66
|
+
) -> None: ...
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Dict
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class HTTPOCR:
|
|
11
|
+
"""Simple OCR provider that forwards uploads to an HTTP endpoint."""
|
|
12
|
+
|
|
13
|
+
def __init__(self, endpoint: str, *, timeout: float = 30.0) -> None:
|
|
14
|
+
self.endpoint = endpoint
|
|
15
|
+
self.timeout = timeout
|
|
16
|
+
self._lock = threading.Lock()
|
|
17
|
+
self._results: Dict[str, dict[str, object]] = {}
|
|
18
|
+
|
|
19
|
+
def submit(
|
|
20
|
+
self,
|
|
21
|
+
*,
|
|
22
|
+
user_id: str,
|
|
23
|
+
filename: str,
|
|
24
|
+
data: bytes,
|
|
25
|
+
document_id: str | None,
|
|
26
|
+
) -> str:
|
|
27
|
+
files = {"file": (filename, data)}
|
|
28
|
+
response = requests.post(self.endpoint, files=files, timeout=self.timeout)
|
|
29
|
+
response.raise_for_status()
|
|
30
|
+
payload: dict[str, object]
|
|
31
|
+
try:
|
|
32
|
+
payload = response.json()
|
|
33
|
+
except ValueError:
|
|
34
|
+
payload = {"extracted_text": response.text}
|
|
35
|
+
|
|
36
|
+
extracted = (
|
|
37
|
+
payload.get("extracted_text")
|
|
38
|
+
or payload.get("text")
|
|
39
|
+
or payload.get("content")
|
|
40
|
+
)
|
|
41
|
+
if not extracted:
|
|
42
|
+
raise RuntimeError("OCR endpoint returned an empty response.")
|
|
43
|
+
|
|
44
|
+
task_id = str(uuid.uuid4())
|
|
45
|
+
result = {
|
|
46
|
+
"status": "completed",
|
|
47
|
+
"extracted_text": extracted,
|
|
48
|
+
"document_id": document_id,
|
|
49
|
+
"user_id": user_id,
|
|
50
|
+
}
|
|
51
|
+
with self._lock:
|
|
52
|
+
self._results[task_id] = result
|
|
53
|
+
return task_id
|
|
54
|
+
|
|
55
|
+
def status(self, task_id: str) -> dict[str, object]:
|
|
56
|
+
with self._lock:
|
|
57
|
+
result = self._results.get(task_id)
|
|
58
|
+
if result is None:
|
|
59
|
+
return {"status": "unknown", "task_id": task_id}
|
|
60
|
+
return result
|