compair-core 0.3.7__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compair_core/__init__.py +8 -0
- compair_core/api.py +3357 -0
- {compair → compair_core/compair}/__init__.py +10 -14
- {compair_core-0.3.7.dist-info → compair_core-0.3.8.dist-info}/METADATA +2 -2
- compair_core-0.3.8.dist-info/RECORD +38 -0
- compair_core-0.3.8.dist-info/top_level.txt +1 -0
- compair_core-0.3.7.dist-info/RECORD +0 -36
- compair_core-0.3.7.dist-info/top_level.txt +0 -3
- {compair → compair_core/compair}/celery_app.py +0 -0
- {compair → compair_core/compair}/default_groups.py +0 -0
- {compair → compair_core/compair}/embeddings.py +0 -0
- {compair → compair_core/compair}/feedback.py +0 -0
- {compair → compair_core/compair}/logger.py +0 -0
- {compair → compair_core/compair}/main.py +0 -0
- {compair → compair_core/compair}/models.py +0 -0
- {compair → compair_core/compair}/schema.py +0 -0
- {compair → compair_core/compair}/tasks.py +0 -0
- {compair → compair_core/compair}/utils.py +0 -0
- {compair_email → compair_core/compair_email}/__init__.py +0 -0
- {compair_email → compair_core/compair_email}/email.py +0 -0
- {compair_email → compair_core/compair_email}/email_core.py +0 -0
- {compair_email → compair_core/compair_email}/templates.py +0 -0
- {compair_email → compair_core/compair_email}/templates_core.py +0 -0
- {server → compair_core/server}/__init__.py +0 -0
- {server → compair_core/server}/app.py +0 -0
- {server → compair_core/server}/deps.py +0 -0
- {server → compair_core/server}/local_model/__init__.py +0 -0
- {server → compair_core/server}/local_model/app.py +0 -0
- {server → compair_core/server}/providers/__init__.py +0 -0
- {server → compair_core/server}/providers/console_mailer.py +0 -0
- {server → compair_core/server}/providers/contracts.py +0 -0
- {server → compair_core/server}/providers/local_storage.py +0 -0
- {server → compair_core/server}/providers/noop_analytics.py +0 -0
- {server → compair_core/server}/providers/noop_billing.py +0 -0
- {server → compair_core/server}/providers/noop_ocr.py +0 -0
- {server → compair_core/server}/routers/__init__.py +0 -0
- {server → compair_core/server}/routers/capabilities.py +0 -0
- {server → compair_core/server}/settings.py +0 -0
- {compair_core-0.3.7.dist-info → compair_core-0.3.8.dist-info}/WHEEL +0 -0
- {compair_core-0.3.7.dist-info → compair_core-0.3.8.dist-info}/licenses/LICENSE +0 -0
compair_core/api.py
ADDED
|
@@ -0,0 +1,3357 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import requests
|
|
5
|
+
import secrets
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from datetime import datetime, timedelta, timezone
|
|
9
|
+
from typing import Any, Mapping, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
import psutil
|
|
13
|
+
from celery.result import AsyncResult
|
|
14
|
+
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile
|
|
15
|
+
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
|
|
16
|
+
from sqlalchemy import distinct, func, select, or_
|
|
17
|
+
from sqlalchemy.orm import joinedload, Session
|
|
18
|
+
|
|
19
|
+
from .server.deps import get_analytics, get_billing, get_ocr, get_settings_dependency, get_storage
|
|
20
|
+
from .server.providers.contracts import Analytics, BillingProvider, OCRProvider, StorageProvider
|
|
21
|
+
from .server.settings import Settings
|
|
22
|
+
|
|
23
|
+
from . import compair
|
|
24
|
+
from .compair import models, schema
|
|
25
|
+
from .compair.embeddings import create_embedding, Embedder
|
|
26
|
+
from .compair.logger import log_event
|
|
27
|
+
from .compair.utils import chunk_text, generate_verification_token, log_activity
|
|
28
|
+
from .compair_email.email import emailer, EMAIL_USER
|
|
29
|
+
from .compair_email.templates import (
|
|
30
|
+
ACCOUNT_VERIFY_TEMPLATE,
|
|
31
|
+
GROUP_INVITATION_TEMPLATE,
|
|
32
|
+
GROUP_JOIN_TEMPLATE,
|
|
33
|
+
INDIVIDUAL_INVITATION_TEMPLATE,
|
|
34
|
+
PASSWORD_RESET_TEMPLATE,
|
|
35
|
+
REFERRAL_CREDIT_TEMPLATE
|
|
36
|
+
)
|
|
37
|
+
from .compair.tasks import process_document_task as process_document_celery, send_feature_announcement_task, send_deactivate_request_email, send_help_request_email
|
|
38
|
+
|
|
39
|
+
import redis
|
|
40
|
+
|
|
41
|
+
# Redis client shared by this module, configured from the REDIS_URL env var.
# The client object is constructed at import time.
# NOTE(review): from_url() is called even when REDIS_URL is unset (redis_url is
# None), which will most likely make the import fail — confirm Redis is
# mandatory for every edition.
redis_url = os.environ.get("REDIS_URL")
redis_client = redis.Redis.from_url(redis_url)
#from compair.main import process_document

# Router that every endpoint in this module registers on.
router = APIRouter()
# Front-end host; interpolated into e-mail links as f"https://{WEB_URL}/...".
WEB_URL = os.environ.get("WEB_URL")
# NOTE(review): not referenced in the visible part of this module; presumably
# guards admin-only endpoints — confirm.
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY")

# Cloudflare Images settings. The derived URLs below are formatted once at
# import time, so an unset variable ends up as the literal text "None" in them.
CLOUDFLARE_IMAGES_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_ACCOUNT")
CLOUDFLARE_IMAGES_URL_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_URL_ACCOUNT")
CLOUDFLARE_IMAGES_BASE_URL = f"https://imagedelivery.net/{CLOUDFLARE_IMAGES_URL_ACCOUNT}"
CLOUDFLARE_IMAGES_TOKEN = os.environ.get("CLOUDFLARE_IMAGES_TOKEN")
CLOUDFLARE_IMAGES_UPLOAD_URL = f"https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_IMAGES_ACCOUNT}/images/v1"

# Google Analytics 4 credentials (not used in the visible part of this module).
GA4_MEASUREMENT_ID = os.getenv("GA4_MEASUREMENT_ID")
GA4_API_SECRET = os.getenv("GA4_API_SECRET")

# Edition switch: COMPAIR_EDITION=cloud (case-insensitive) selects the Cloud
# edition; anything else, including unset (default "core"), is the core edition.
IS_CLOUD = os.getenv("COMPAIR_EDITION", "core").lower() == "cloud"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def require_cloud(feature: str) -> None:
    """Abort the current request unless this deployment is the Cloud edition.

    Args:
        feature: Human-readable feature name placed in the error detail.

    Raises:
        HTTPException: 501 when ``IS_CLOUD`` is false.
    """
    if IS_CLOUD:
        return
    message = f"{feature} is only available in the Compair Cloud edition."
    raise HTTPException(status_code=501, detail=message)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _user_plan(user: models.User) -> str:
    """Return the user's plan name, defaulting to ``"free"``.

    The default applies both when the model has no ``plan`` attribute (core
    edition) and when the stored value is empty or ``None``.
    """
    plan = getattr(user, "plan", None)
    return plan if plan else "free"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _user_team(user: models.User):
    """Return ``user.team``, or ``None`` when the model has no team attribute."""
    try:
        return user.team
    except AttributeError:
        return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _trial_expiration(user: models.User) -> datetime | None:
    """Return the user's trial expiration date, or ``None`` if trials are not modelled."""
    try:
        return user.trial_expiration_date
    except AttributeError:
        return None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Capability flags: detect which optional (Cloud-edition) model features are
# installed, so endpoints can skip them or reject them with HTTP 501.
HAS_TEAM = hasattr(models, "Team")  # team models (gates team-invitation e-mails)
HAS_ACTIVITY = hasattr(models, "Activity")  # activity log model
HAS_REFERRALS = hasattr(models.User, "referral_code")  # referral program columns
HAS_BILLING = hasattr(models.User, "stripe_customer_id")  # billing columns
HAS_TRIALS = hasattr(models.User, "trial_expiration_date")  # trial tracking column
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def require_feature(flag: bool, feature: str) -> None:
    """Abort the request with HTTP 501 when an optional feature is unavailable.

    Args:
        flag: Whether the feature is supported (usually one of the ``HAS_*`` flags).
        feature: Human-readable feature name placed in the error detail.

    Raises:
        HTTPException: 501 when ``flag`` is falsy.
    """
    if flag:
        return
    raise HTTPException(
        status_code=501,
        detail=f"{feature} is only available in the Compair Cloud edition.",
    )
|
|
88
|
+
|
|
89
|
+
def get_current_user(auth_token: str = Header(...)):
    """FastAPI dependency that resolves the ``auth_token`` header to a user.

    FastAPI reads this parameter from the ``auth-token`` request header.

    Returns:
        The ``models.User`` owning the session, with ``groups`` eagerly loaded
        so it stays readable after the DB session closes.

    Raises:
        HTTPException: 401 for an unknown or expired token, 404 when the
            session's user no longer exists.
    """
    with compair.Session() as session:
        token_row = (
            session.query(models.Session)
            .filter(models.Session.id == auth_token)
            .first()
        )
        if not token_row:
            raise HTTPException(status_code=401, detail="Invalid or expired session token")

        # Naive timestamps from the database are interpreted as UTC.
        expires_at = token_row.datetime_valid_until
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at < datetime.now(timezone.utc):
            raise HTTPException(status_code=401, detail="Invalid or expired session token")

        owner = (
            session.query(models.User)
            .options(joinedload(models.User.groups))
            .filter(models.User.user_id == token_row.user_id)
            .first()
        )
        if not owner:
            raise HTTPException(status_code=404, detail="User not found")
        return owner
|
|
110
|
+
|
|
111
|
+
def get_current_user_with_access_to_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> models.User:
    """Dependency returning ``current_user`` only if they may access ``document_id``.

    Access is granted when the user authored the document, shares at least
    one group with it, or the document is published.

    Raises:
        HTTPException: 404 if the document does not exist, 403 if none of the
            access rules match.
    """
    with compair.Session() as session:
        doc = (
            session.query(models.Document)
            .filter(models.Document.document_id == document_id)
            .first()
        )
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        # NOTE(review): other queries in this module filter on
        # Document.user_id; confirm Document also exposes author_id.
        if doc.author_id == current_user.user_id:
            return current_user

        member_of = {grp.group_id for grp in current_user.groups}
        if any(grp.group_id in member_of for grp in doc.groups):
            return current_user

        if doc.is_published:
            return current_user

        raise HTTPException(status_code=403, detail="Not authorized to access this document")
|
|
131
|
+
|
|
132
|
+
def log_service_resource_metrics(service_name="backend", interval: float = 300):
    """Log this process's memory and CPU usage now, then every ``interval`` seconds.

    Each sample emits a ``service_resource`` event through ``log_event``.
    Re-scheduling happens in ``finally``, so a failed sample does not stop the
    loop; failures are printed instead of raised.

    Args:
        service_name: Value recorded in the ``service`` field of every event.
        interval: Seconds between samples (default 300, i.e. five minutes).

    Note:
        The first sample runs synchronously in the caller's thread, and
        ``cpu_percent(interval=1)`` blocks for about one second.
    """
    def log():
        try:
            proc = psutil.Process(os.getpid())
            mem_mb = round(proc.memory_info().rss / 1024 / 1024, 2)
            cpu_percent = proc.cpu_percent(interval=1)
            log_event("service_resource", service=service_name, memory_mb=mem_mb, cpu_percent=cpu_percent)
        except Exception as e:
            print(f"[Resource Log Error] {e}")
        finally:
            # Daemon timer: a non-daemon timer thread would keep the
            # interpreter alive at shutdown, waiting on the next sample.
            timer = threading.Timer(interval, log)
            timer.daemon = True
            timer.start()

    log()
|
|
145
|
+
|
|
146
|
+
# Start periodic resource-usage logging as an import-time side effect of this
# module (the original note suggests service_name="frontend" for the front-end
# service). The first sample runs synchronously, so importing this module
# blocks for roughly one second while psutil measures CPU usage.
log_service_resource_metrics(service_name="backend")

# Run via: fastapi dev api.py
# NOTE(review): this module defines an APIRouter (not an app) and uses relative
# imports; confirm this launch command is still accurate.
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@router.post("/login")
def login(request: schema.LoginRequest) -> dict:
    """Authenticate a user and open a new session valid for one day.

    Args:
        request: Login payload carrying ``username`` and ``password``.

    Returns:
        The user's id, username, name, status and role, plus ``auth_token``:
        the new session id the client sends back as the ``auth-token`` header.

    Raises:
        HTTPException: 401 for an unknown username or wrong password, 403 when
            the account status is ``inactive`` (not yet verified).
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.username == request.username).first()
        # Credentials must never be printed or logged.
        if not user or not user.check_password(request.password):
            raise HTTPException(status_code=401, detail="Invalid credentials")
        if user.status == 'inactive':
            raise HTTPException(status_code=403, detail="User account is not verified")
        now = datetime.now(tz=timezone.utc)
        user_session = models.Session(
            id=secrets.token_urlsafe(),
            user_id=user.user_id,
            datetime_created=now,
            datetime_valid_until=now + timedelta(days=1),
        )
        session.add(user_session)
        session.commit()
        return {
            "user_id": user.user_id,
            "username": user.username,
            "name": user.name,
            "status": user.status,
            "role": user.role,
            "auth_token": user_session.id,  # Return the session token here
        }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@router.get("/username_exists")
def username_exists(username: str) -> dict:
    """Report whether an account with exactly this username exists.

    Returns:
        ``{"exists": bool}``.
    """
    with compair.Session() as session:
        match = (
            session.query(models.User)
            .filter(models.User.username == username)
            .first()
        )
    return {"exists": match is not None}
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@router.get("/load_user")
def load_user(
    username: str,
) -> schema.User | None:
    """Look up a user by exact username.

    Returns:
        The matching user (serialized through ``schema.User``), or ``None``
        when no account has that username.

    NOTE(review): this endpoint requires no authentication and returns a
    full user profile — confirm that is intended.
    """
    with compair.Session() as session:
        # Exact comparison, consistent with /username_exists and /login.
        # Column.match() compiles to a backend-specific full-text MATCH
        # operator (and fails on plain SQLite columns), not to equality.
        q = select(models.User).filter(
            models.User.username == username
        )
        row = session.execute(q).fetchone()
        if row is None:
            return None
        user = row[0]
        if user.groups is None:
            user.groups = []
        return user
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@router.get("/load_user_plan")
def load_user_plan(
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return the authenticated user's plan as ``{"plan": ...}`` (``"free"`` by default).

    Given its own function name: the decorator registers the route either
    way, but reusing a module-level name (ruff F811) lets a later definition
    silently rebind it.
    """
    return {'plan': _user_plan(current_user)}
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@router.get("/load_user_by_id")
def load_user(
    user_id: str,
) -> schema.User | None:
    """Look up a user by ``user_id``.

    The function keeps the name ``load_user`` because, as the last definition
    with that name in this module, it is what ``load_user`` refers to at
    module level; the route itself is ``/load_user_by_id``.

    Returns:
        The matching user (serialized through ``schema.User``), or ``None``.

    NOTE(review): this endpoint requires no authentication — confirm intended.
    """
    with compair.Session() as session:
        q = select(models.User).filter(
            models.User.user_id == user_id
        )
        row = session.execute(q).fetchone()
        if row is None:
            return None
        user = row[0]
        if user.groups is None:
            user.groups = []
        # No debug print: the user's repr may expose personal data in logs.
        return user
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@router.get("/load_user_files")
def load_user_files(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """List another user's ("connection's") documents that the caller may see.

    The caller must share at least one group with the connection, and the
    connection must not have ``hide_affiliations`` enabled. Eligible documents
    are those in a shared group, or in a public group the connection belongs to.

    Args:
        connection_id: ``user_id`` of the user whose documents are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Documents per page; negative values are treated as 0.
        filter_type: ``"published"``, ``"unpublished"``, ``"recently_updated"``
            (modified, or given a note, within 7 days), ``"recently_compaired"``
            (received feedback within 7 days), or ``None`` for all.

    Returns:
        ``{"files": [...], "message": str | None, "total_count": int}``.
    """
    with compair.Session() as session:
        # Validate connection
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"files": [], "message": "User or connection not found.", "total_count": 0}

        # Security: must share a group
        shared_group_ids = {g.group_id for g in current_user.groups}.intersection(
            g.group_id for g in connection.groups
        )
        if not shared_group_ids:
            return {
                "files": [],
                "message": "You do not have permission to view this user's affiliations, as you do not share a group together.",
                "total_count": 0,
            }

        # Privacy: check setting
        if connection.hide_affiliations:
            return {"files": [], "message": f"{connection.username} has set their profile to private.", "total_count": 0}

        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        # Documents belonging to shared groups, or public groups of the connection
        q = select(models.Document).join(models.Document.groups).filter(
            models.Document.user_id == connection_id,
            models.Group.group_id.in_(shared_group_ids) |
            (
                (models.Group.visibility == "public") &
                models.Group.users.any(models.User.user_id == connection_id)
            )
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups)
        )

        # --- Filter logic ---
        if filter_type == "published":
            q = q.filter(models.Document.is_published == True)
        elif filter_type == "unpublished":
            q = q.filter(models.Document.is_published == False)
        elif filter_type == "recently_updated":
            # Documents updated OR with a note in the past week
            q = q.outerjoin(models.Document.notes).filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Note.datetime_created >= week_ago
                )
            )
        elif filter_type == "recently_compaired":
            # Documents with feedback in the past week
            q = q.join(models.Document.chunks).join(models.Chunk.feedbacks).filter(
                models.Feedback.timestamp >= week_ago
            )
        # Most recently modified first; creation time breaks ties.
        q = q.order_by(models.Document.datetime_modified.desc(), models.Document.datetime_created.desc())

        # The group/note/feedback joins can return one document on several
        # rows, so SQL LIMIT/OFFSET would page over join rows, not documents.
        # De-duplicate first (unique()), then page in Python.
        all_files = session.execute(q).unique().scalars().all()
        total_count = len(all_files)
        size = max(page_size, 0)
        offset = (max(page, 1) - 1) * size
        files = all_files[offset:offset + size]

        return {
            "files": [schema.Document.model_validate(f) for f in files],
            "message": None,
            "total_count": total_count,
        }
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@router.get("/load_user_groups")
def load_user_groups(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """List another user's groups that the caller may see.

    The caller must share at least one group with the connection, and the
    connection must not have ``hide_affiliations`` enabled. Returned groups
    are the shared ones plus public groups the connection belongs to.

    Args:
        connection_id: ``user_id`` of the user whose groups are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Groups per page; negative values are treated as 0.
        filter_type: ``"internal"``, ``"public"`` or ``"private"`` (by
            visibility), ``"recently_updated"`` (a document created within 7
            days), or ``None`` for all.

    Returns:
        ``{"groups": [...], "message": str | None, "total_count": int}``.
    """
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        # Validate connection
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"groups": [], "message": "User or connection not found.", "total_count": 0}

        # Check if the connection is valid
        shared_group_ids = {g.group_id for g in current_user.groups}.intersection(
            g.group_id for g in connection.groups
        )
        if not shared_group_ids:
            return {
                "groups": [],
                "message": "You do not have permission to view this user's affiliations, as you do not share a group together.",
                "total_count": 0,
            }

        # Privacy: check setting
        if connection.hide_affiliations:
            return {"groups": [], "message": f"{connection.username} has set their profile to private.", "total_count": 0}

        # Fetch shared or public groups of the connection
        groups_query = session.query(models.Group).filter(
            models.Group.group_id.in_(shared_group_ids) |
            (
                (models.Group.visibility == "public") &
                models.Group.users.any(models.User.user_id == connection_id)
            )
        ).options(
            joinedload(models.Group.users),
            joinedload(models.Group.documents)
        )

        # --- Filter logic (match /load_groups) ---
        if filter_type in ("internal", "public", "private"):
            groups_query = groups_query.filter(models.Group.visibility == filter_type)
        elif filter_type == "recently_updated":
            # EXISTS rather than JOIN, so a group with several recent documents
            # is counted and paged exactly once.
            groups_query = groups_query.filter(
                models.Group.documents.any(models.Document.datetime_created >= week_ago)
            )
        # A stable order for every filter keeps LIMIT/OFFSET paging deterministic.
        groups_query = groups_query.order_by(models.Group.name.asc())

        total_count = groups_query.count()
        size = max(page_size, 0)
        offset = (max(page, 1) - 1) * size
        groups = groups_query.offset(offset).limit(size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": getattr(group, "document_count", None),
                "user_count": getattr(group, "user_count", None),
                "first_three_user_profile_images": getattr(group, "first_three_user_profile_images", None)
            }
            for group in groups
        ]
        return {"groups": result, "message": None, "total_count": total_count}
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@router.get("/load_user_status")
def load_user_status(
    current_user: models.User = Depends(get_current_user)
) -> str:
    """Return the authenticated user's account status string.

    Statuses handled elsewhere in this module include ``"active"``,
    ``"trial"``, ``"suspended"`` and ``"inactive"``. The value comes straight
    from the already-loaded user, so no database session is opened.
    """
    return current_user.status
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@router.get("/load_user_status_date")
def load_user_status(
    current_user: models.User = Depends(get_current_user)
) -> datetime | None:
    """Return the date associated with the user's current status.

    - ``"active"`` → ``last_payment_date`` (requires billing support)
    - ``"trial"`` → ``trial_expiration_date`` (requires trial support)
    - ``"suspended"`` → ``status_change_date``

    The return annotation allows ``None`` because FastAPI validates the
    response against it; a status whose date is unset would otherwise be
    turned into a 500 response-validation error.

    Raises:
        HTTPException: 501 when neither billing nor trials are available (or
            the one needed for this status is missing); 403 for any other
            status, e.g. ``"inactive"``.
    """
    if not (HAS_TRIALS or HAS_BILLING):
        raise HTTPException(status_code=501, detail="User status dates are only tracked in the Compair Cloud edition.")
    user_status = current_user.status
    if user_status == 'active':
        require_feature(HAS_BILLING, "Billing history")
        return current_user.last_payment_date
    if user_status == 'trial':
        require_feature(HAS_TRIALS, "Trial management")
        return current_user.trial_expiration_date
    if user_status == 'suspended':
        return current_user.status_change_date
    raise HTTPException(status_code=403, detail='User Inactive')
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@router.get("/load_referral_credits")
def load_referral_credits(
    current_user: models.User = Depends(get_current_user)
) -> Tuple[int, int]:
    """Return ``(earned, pending)`` referral credits of the authenticated user.

    Both values are read from the already-loaded user, so no extra database
    session is opened.

    Raises:
        HTTPException: 501 when the referral program is not available.
    """
    if not HAS_REFERRALS:
        raise HTTPException(status_code=501, detail="Referral credits are only available in the Compair Cloud edition.")
    return (current_user.referral_credits, current_user.pending_referral_credits)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _send_pending_team_invitations(session: Session, email: str) -> None:
    """E-mail every pending team invitation addressed to ``email`` and mark it ``"sent"``."""
    team_invitations = session.query(models.TeamInvitation).filter(
        models.TeamInvitation.email == email,
        models.TeamInvitation.status == "pending",
    ).all()

    for invitation in team_invitations:
        emailer.connect()
        emailer.send(
            subject="You're Invited to Join a Team on Compair",
            sender=EMAIL_USER,
            receivers=[invitation.email],
            html=f"""
            <p>{invitation.inviter.name} has invited you to join their team on Compair!</p>
            <p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>
            """
        )
        invitation.status = "sent"
    session.commit()


def _assign_default_groups(session: Session, user: models.User, username: str) -> None:
    """Give a new user a private personal group (with an admin record) plus the welcome group.

    When a group named "Welcome to Compair" exists, an unpublished
    introduction document is also created in it via ``create_doc``; when it
    does not exist, only the personal group is assigned.
    """
    welcome_group = session.query(models.Group).filter(
        models.Group.name == "Welcome to Compair"
    ).first()
    welcome_doc = None
    if welcome_group is not None:
        welcome_doc_id = create_doc(
            authorid=user.user_id,
            document_title=f'Introducing {user.name} to Compair',
            document_type='txt',
            document_content="",
            groups=welcome_group.group_id,
            is_published=False,
            current_user=user
        )['document_id']
        welcome_doc = session.query(models.Document).filter(
            models.Document.document_id == welcome_doc_id
        ).first()

    group = models.Group(
        name=username,
        datetime_created=datetime.now(timezone.utc),
        group_image=None,
        category="Private",
        description=f"A private group for {username}",
        visibility="private"
    )
    admin = models.Administrator(user_id=user.user_id)
    session.add(admin)
    session.add(group)
    if welcome_doc is not None:
        session.add(welcome_doc)
    session.commit()
    group.admins.append(admin)
    user.groups = [g for g in (group, welcome_group) if g is not None]


def _record_referral(session: Session, user: models.User, referral_code: str) -> None:
    """Record a pending referral credit for the user owning ``referral_code``.

    Credits are tracked as dollar amounts ($10 per referral). A referrer is
    capped at three earned credits; unknown codes are ignored.
    """
    credit_value = 10  # $10 per credit
    max_credits = 3
    referrer = session.query(models.User).filter(models.User.referral_code == referral_code).first()
    if not referrer or not hasattr(referrer, "referral_credits"):
        return
    # Strict "<": with "<=" a referrer already at the cap ($30) could earn a fourth credit.
    if (referrer.referral_credits or 0) < max_credits * credit_value:
        if hasattr(user, "referred_by"):
            user.referred_by = referrer.user_id  # Store who referred them
        if hasattr(referrer, "pending_referral_credits"):
            referrer.pending_referral_credits = (referrer.pending_referral_credits or 0) + credit_value
        session.commit()


def create_user(
    username: str,
    name: str,
    password: str,
    session: Session,
    groups: list[str] | None = None,
    referral_code: str | None = None,
    analytics: Analytics | None = None,
):
    """Create, persist and return a new user account.

    Steps:
    1. Reject duplicate usernames.
    2. Store the user with a fresh verification token.
    3. When team support exists, e-mail pending team invitations addressed to
       the username.
    4. Attach either the requested ``groups`` or the default groups (a private
       personal group plus the welcome group).
    5. Record a pending referral credit for the referrer.
    6. Emit a ``user_signup`` analytics event when a provider is supplied.

    Args:
        username: Login name; also used as the e-mail address for invitations.
        name: Display name.
        password: Plain-text password, hashed by ``User.set_password``.
        session: Open SQLAlchemy session; committed several times.
        groups: Optional group ids to join instead of the default groups.
        referral_code: Optional referral code of an existing user.
        analytics: Optional analytics provider; tracking is skipped when ``None``.

    Returns:
        The newly created ``models.User``.

    Raises:
        HTTPException: 501 when a referral code is given but the referral
            program is unavailable (checked before anything is written); 400
            when the username is already taken.
    """
    if referral_code:
        # Validate up front so an unsupported referral code cannot leave a
        # committed account behind a 501 response.
        require_feature(HAS_REFERRALS, "Referral program")

    existing_user = session.query(models.User).filter(models.User.username == username).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already in use")

    token, expiration = generate_verification_token()
    user = models.User(
        username=username,
        name=name,
        # UTC, matching the rest of this module (naive stored values are read back as UTC).
        datetime_registered=datetime.now(timezone.utc),
        verification_token=token.lower(),
        token_expiration=expiration,
    )
    user.set_password(password=password)
    session.add(user)
    session.commit()

    if HAS_TEAM:
        _send_pending_team_invitations(session, username)

    if groups is not None:
        user.groups = load_groups_by_ids(groups)
    else:
        _assign_default_groups(session, user, username)

    if referral_code:
        _record_referral(session, user, referral_code)

    session.add(user)
    session.commit()

    if analytics is not None:
        try:
            analytics.track("user_signup", user.user_id)
        except Exception as exc:
            # Analytics is best-effort and must never fail a signup.
            print(f"analytics track failed: {exc}")
    return user
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@router.get("/load_session")
def load_session(auth_token: str) -> schema.Session | None:
    """Return the session row for ``auth_token`` if it exists and has not expired.

    Raises:
        HTTPException: 404 for an unknown token, 401 for an expired one.
    """
    with compair.Session() as session:
        found = session.query(models.Session).filter(models.Session.id == auth_token).first()
        if not found:
            raise HTTPException(status_code=404, detail="Session not found")

        # Naive timestamps are interpreted as UTC.
        expiry = found.datetime_valid_until
        expiry = expiry if expiry.tzinfo is not None else expiry.replace(tzinfo=timezone.utc)
        if datetime.now(timezone.utc) > expiry:
            raise HTTPException(status_code=401, detail="Invalid or expired session token")
        return found
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
@router.post("/update_user")
def update_user(
    name: str = Form(None),
    role: str = Form(None),
    group_ids: list[str] = Form(None),
    include_own_documents_in_feedback: str = Form(None),
    default_publish: str = Form(None),
    preferred_feedback_length: str = Form(None),
    hide_affiliations: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Apply submitted profile and preference fields to the authenticated user.

    Only fields actually submitted (not ``None``) are changed. Boolean-like
    fields arrive as strings and count as true only when equal to ``"true"``
    (case-insensitive). ``group_ids`` only adds memberships, never removes
    them. ``preferred_feedback_length`` is currently forced to ``"Brief"``
    whatever value is sent.

    NOTE(review): group_ids is not checked against group visibility or
    invitations here, so a caller appears able to add themselves to any group
    id — confirm whether load_groups_by_ids enforces access.
    """
    def _truthy(raw: str) -> bool:
        return raw.lower() == "true"

    with compair.Session() as session:
        user = current_user
        if name is not None:
            user.name = name
        if role is not None:
            user.role = role
        if group_ids is not None:
            already_in = user.groups
            joining = [grp for grp in load_groups_by_ids(group_ids) if grp not in already_in]
            user.groups.extend(joining)
        if include_own_documents_in_feedback is not None:
            user.include_own_documents_in_feedback = _truthy(include_own_documents_in_feedback)
        if default_publish is not None:
            user.default_publish = _truthy(default_publish)
        if preferred_feedback_length is not None:
            # Feedback length is pinned to "Brief" for the time being.
            user.preferred_feedback_length = 'Brief'
        if hide_affiliations is not None:
            user.hide_affiliations = _truthy(hide_affiliations)
        session.add(user)
        session.commit()
|
|
575
|
+
|
|
576
|
+
@router.get("/update_session_duration")
def update_session_duration(
    user_session: schema.Session,
    new_valid_until: datetime,
) -> None:
    """Move the expiry of an existing login session to ``new_valid_until``.

    NOTE(review): this mutates state behind a GET route and only requires
    knowing the session id — confirm that is intended.

    Raises:
        HTTPException: 404 when no session with ``user_session.id`` exists.
    """
    with compair.Session() as session:
        db_session = session.query(
            models.Session
        ).filter(
            models.Session.id == user_session.id
        ).first()
        if db_session is None:
            raise HTTPException(status_code=404, detail="Session not found")
        # .first() returns a mapped instance, which has no update() method
        # (that belongs to Query); assign the attribute and commit instead.
        db_session.datetime_valid_until = new_valid_until
        session.commit()
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
@router.get("/delete_user")
def delete_user(
    current_user: models.User = Depends(get_current_user)
):
    """Delete the authenticated user's account row.

    NOTE(review): a destructive action behind a GET route; related rows
    (sessions, memberships) are only removed if the model's cascades or
    foreign keys handle them — confirm.
    """
    with compair.Session() as session:
        # current_user was loaded by another, already-closed session.
        # Session.delete() accepts detached instances and the commit issues
        # the DELETE; mapped instances themselves have no delete() method.
        session.delete(current_user)
        session.commit()
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
@router.get("/load_connections")
def load_connections(
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """List users who share at least one group with the caller.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Users per page; negative values are treated as 0.
        filter_type: ``"recently_active"`` (created something within 7 days;
            requires activity tracking), ``"recently_compaired"`` (one of their
            documents received feedback within 7 days), or anything else for
            all connections sorted by name.

    Returns:
        ``{"connections": [...], "total_count": int}``.

    Raises:
        HTTPException: 501 for ``"recently_active"`` when activity tracking
            is not available.
    """
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)

        # Get all groups the user belongs to
        groups = session.query(models.Group).options(joinedload(models.Group.users)).filter(
            models.Group.group_id.in_([g.group_id for g in current_user.groups])
        ).all()
        if not groups:
            return {"connections": [], "total_count": 0}

        # Every other member of those groups, excluding the caller.
        connection_ids = {
            member.user_id
            for group in groups
            for member in group.users
            if member.user_id != current_user.user_id
        }

        q = session.query(models.User).filter(models.User.user_id.in_(connection_ids))

        # --- Filter logic ---
        # EXISTS (.any()) instead of JOINs so each user is counted and paged once.
        if filter_type == "recently_active":
            require_feature(HAS_ACTIVITY, "Activity tracking")
            q = q.filter(models.User.activities.any(
                (models.Activity.action == "create") & (models.Activity.timestamp >= week_ago)
            ))
        elif filter_type == "recently_compaired":
            q = q.filter(models.User.documents.any(
                models.Document.chunks.any(
                    models.Chunk.feedbacks.any(models.Feedback.timestamp >= week_ago)
                )
            ))
        else:
            q = q.order_by(models.User.name.asc())

        total_count = q.count()
        size = max(page_size, 0)
        offset = (max(page, 1) - 1) * size
        connections = q.order_by(models.User.datetime_registered.desc()).offset(offset).limit(size).all()

        return {
            "connections": [
                {
                    "user_id": connection.user_id,
                    "username": connection.username,
                    "name": connection.name,
                    "datetime_registered": connection.datetime_registered,
                    "status": connection.status,
                    "profile_image": connection.profile_image,
                    "role": connection.role,
                }
                for connection in connections
            ],
            "total_count": total_count
        }
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
@router.get("/all_group_categories")
def all_group_categories(
    current_user: models.User = Depends(get_current_user)
):
    """Return every distinct, non-empty group category except ``"Compair"``.

    Authentication is required (via ``current_user``), but results are not
    restricted to the caller's groups.

    Returns:
        ``{"categories": [...]}`` in ascending order.
    """
    with compair.Session() as session:
        rows = (
            session.query(distinct(models.Group.category))
            .order_by(models.Group.category.asc())
            .all()
        )
        names = []
        for (category,) in rows:
            if category and category != 'Compair':
                names.append(category)
        return {"categories": names}
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
@router.get("/load_groups")
def load_groups(
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    category: str | None = None,
    visibility: str | None = None,
    sort: str | None = None,
    query: str | None = None,
    own_groups_only: bool = False,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """Return one page of groups plus the total number of matching groups.

    Args:
        user_id: When ``None``, every non-private group is listed. Otherwise
            groups accessible to the authenticated user are listed; the value
            itself is only used as a switch (access is always computed for
            ``current_user``).
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of groups returned.
        filter_type: ``"pending"`` keeps only groups with an outstanding
            ("sent") invitation for the user. ``"joined"`` adds no filter here;
            membership restriction is done through ``own_groups_only``.
        category: Exact category to match; ``"all"`` (any case) disables it.
        visibility: Exact visibility to match; ``"all"`` disables it.
        sort: ``"popular"`` (member count), ``"recently_updated"`` (latest
            document modification), ``"recently_created"``; anything else
            sorts by name.
        query: Case-insensitive substring matched against the group name.
        own_groups_only: Only groups the user belongs to or is invited to.

    Returns:
        ``{"groups": [...], "total_count": int}``.
    """
    # A page below 1 would produce a negative OFFSET, which the DB rejects.
    page = max(page, 1)
    with compair.Session() as session:
        invited_group_ids: set = set()

        # --- User-based group selection ---
        if user_id is None:
            q = session.query(models.Group).options(
                joinedload(models.Group.users),
                joinedload(models.Group.documents)
            ).filter(
                models.Group.visibility != 'private'
            )
        else:
            user = session.query(models.User).options(
                joinedload(models.User.groups).joinedload(models.Group.users),
                joinedload(models.User.groups).joinedload(models.Group.documents)
            ).filter(
                models.User.user_id == current_user.user_id
            ).first()
            user_group_ids = {g.group_id for g in user.groups if g.category != 'Compair'}

            invitations = session.query(models.GroupInvitation).filter(
                models.GroupInvitation.email == user.username,
                models.GroupInvitation.status == "sent"
            ).all()
            invited_group_ids = {i.group_id for i in invitations}

            accessible_group_ids = user_group_ids | invited_group_ids

            if own_groups_only:
                # Only groups the user is a member of or invited to
                q = session.query(models.Group).filter(
                    models.Group.group_id.in_(accessible_group_ids)
                )
            else:
                # Public groups, plus any group the user belongs to or is invited to
                q = session.query(models.Group).filter(
                    (models.Group.visibility == "public") |
                    (models.Group.group_id.in_(accessible_group_ids))
                )

        # --- Filtering ---
        if category and category.lower() != "all":
            q = q.filter(models.Group.category == category)
        if visibility and visibility.lower() != "all":
            q = q.filter(models.Group.visibility == visibility)
        # filter_type == "joined" needs no extra filter: it is handled by own_groups_only.
        if filter_type == "pending" and user_id:
            # A truthy user_id means the branch above already loaded the
            # user's outstanding invitations; reuse them instead of re-querying.
            q = q.filter(models.Group.group_id.in_(invited_group_ids))
        if query:
            q = q.filter(models.Group.name.ilike(f"%{query}%"))

        # --- Sorting ---
        if sort == "popular":
            q = q.outerjoin(models.Group.users).group_by(models.Group.group_id).order_by(func.count(models.User.user_id).desc())
        elif sort == "recently_updated":
            q = q.outerjoin(models.Group.documents).group_by(models.Group.group_id).order_by(func.max(models.Document.datetime_modified).desc())
        elif sort == "recently_created":
            q = q.order_by(models.Group.datetime_created.desc())
        else:
            q = q.order_by(models.Group.name.asc())

        # --- Paging ---
        total_count = q.count()
        offset = (page - 1) * page_size
        groups = q.offset(offset).limit(page_size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": group.document_count,
                "user_count": group.user_count,
                "first_three_user_profile_images": group.first_three_user_profile_images
            }
            for group in groups
        ]
        return {"groups": result, "total_count": total_count}
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def load_groups_by_ids(group_ids: list[str], session=None) -> list[models.Group]:
    """Fetch the ORM ``Group`` rows whose ids appear in ``group_ids``.

    Args:
        group_ids: Group ids to look up; ids with no matching row are ignored.
        session: Optional open session to run the query in. Pass the caller's
            session when the returned groups will be attached to other ORM
            objects in that session. When omitted, a short-lived session is
            opened and closed here, so the returned objects are detached.

    Returns:
        The matching ``models.Group`` objects (ORM instances, not schema models).
    """
    def _fetch(active_session) -> list[models.Group]:
        # Single shared query so both code paths stay identical.
        return active_session.query(models.Group).filter(
            models.Group.group_id.in_(group_ids)
        ).all()

    if session is not None:
        return _fetch(session)
    with compair.Session() as own_session:
        return _fetch(own_session)
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
@router.get("/load_group")
def load_group(
    name: str | None = None,
    group_id: str | None = None
) -> schema.Group | None:
    """Look up one group by id, or by name when no id is supplied.

    The id takes precedence when both are given. The name is compared with
    the column's ``match`` operator, not plain equality.

    Returns:
        The first matching group, or ``None`` when nothing matches or when
        neither argument is given.
    """
    if name is None and group_id is None:
        return None
    with compair.Session() as session:
        if group_id is None:
            criterion = models.Group.name.match(name)
        else:
            criterion = models.Group.group_id == group_id
        found = session.execute(select(models.Group).filter(criterion)).fetchone()
        return None if found is None else found[0]
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
def notify_group_admins(
    group: models.Group,
    user_id: str
):
    """Email every admin of ``group`` that ``user_id`` asked to join it.

    The group's ``admins`` and each admin's ``user`` are read from ``group``
    itself, so it is expected to still be attached to an open session. The
    requesting user is looked up in a separate session. Nothing is sent when
    the group has no admins or the requesting user does not exist.

    Args:
        group: The group the join request targets.
        user_id: Id of the user who requested to join.
    """
    from html import escape

    admin_emails = [admin.user.username for admin in group.admins]
    if not admin_emails:
        print("No admins found for group:", group.name)
        return

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if user is None:
            print("Join request notification skipped; user not found:", user_id)
            return
        requester_name = user.username

    # The username and group name are user-controlled; escape them before
    # substituting into the HTML template so they cannot inject markup.
    html_body = (
        GROUP_JOIN_TEMPLATE
        .replace("{{ user_name }}", escape(requester_name or ""))
        .replace("{{ group_name }}", escape(group.name or ""))
        .replace("{{ admin_panel_url }}", f"http://{WEB_URL}/admin/groups")
    )
    emailer.connect()
    emailer.send(
        subject="Group Join Request",
        sender=EMAIL_USER,
        receivers=admin_emails,
        html=html_body,
    )
|
|
840
|
+
|
|
841
|
+
@router.post("/join_group")
def join_group(
    group_id: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """Endpoint wrapper: join ``group_id`` as the authenticated user.

    All visibility rules live in ``join_group_direct``; its response (or
    exception) is passed through unchanged.
    """
    requester_id = current_user.user_id
    return join_group_direct(user_id=requester_id, group_id=group_id)
|
|
850
|
+
|
|
851
|
+
def join_group_direct(
    user_id: str,
    group_id: str
):
    """Join a group, or request to join it, according to its visibility.

    * ``public``: the user is added immediately.
    * ``internal``: with an outstanding ("sent") invitation the user is added
      immediately; otherwise a JoinRequest is created (once) and the group
      admins are emailed.
    * ``private``: rejected.

    Any outstanding invitations to the group for this user are marked
    "accepted" when the user is added.

    Args:
        user_id: Id of the joining user.
        group_id: Id of the target group.

    Returns:
        ``{"message": ...}`` describing the outcome, or ``None`` for a
        visibility value other than the three above.

    Raises:
        HTTPException: 404 if the group or user does not exist; 403 for a
            private group.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        if group.visibility == "private":
            raise HTTPException(status_code=403, detail="Cannot join private group without an invite")
        if group.visibility not in ("public", "internal"):
            return None

        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")

        # Look for any outstanding invitations associated with this group
        invitations = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username,
            models.GroupInvitation.status == "sent"
        ).all()

        if group.visibility == "public" or invitations:
            for invitation in invitations:
                invitation.status = "accepted"
            # Compare by id so an existing member is not appended a second time.
            newly_joined = all(member.user_id != user_id for member in group.users)
            if newly_joined:
                group.users.append(user)
            session.commit()
            if newly_joined:
                log_activity(
                    session=session,
                    user_id=user_id,
                    group_id=group.group_id,
                    action="join",
                    object_id=group.group_id,
                    object_name=group.name,
                    object_type="group"
                )
            return {"message": "Joined group successfully"}

        # Internal group without an invitation: create a JoinRequest if not already present
        existing_request = session.query(models.JoinRequest).filter(
            models.JoinRequest.user_id == user_id,
            models.JoinRequest.group_id == group_id
        ).first()
        if not existing_request:
            join_request = models.JoinRequest(
                user_id=user_id,
                group_id=group_id,
                datetime_requested=datetime.now(timezone.utc)
            )
            session.add(join_request)
            session.commit()
            notify_group_admins(group, user_id)
        return {"message": "Join request sent to group admins"}
|
|
912
|
+
|
|
913
|
+
|
|
914
|
+
@router.post("/create_group")
async def create_group(
    name: str = Form(...),
    category: str = Form(None),
    description: str = Form(None),
    visibility: str = Form("public"),
    file: UploadFile = File(None),  # Allow optional file upload
    current_user: models.User = Depends(get_current_user)
):
    """Create a group administered by the current user.

    The creator gets an ``Administrator`` row (created if missing), is set as
    the group's admin, and is added as a member. Unknown categories fall back
    to ``"Other"``. An optional image is uploaded after the group exists.

    Returns:
        ``{"group_id": ...}`` for the new group.

    Raises:
        HTTPException: 400 for an unknown visibility; 403 when creating an
            ``internal`` group without an active team plan.
    """
    if visibility not in ("public", "internal", "private"):
        raise HTTPException(status_code=400, detail="Invalid group visibility")

    if category not in all_group_categories()['categories']:
        category = "Other"  # Default to "Other" if category is not valid

    # Limit internal group creation to active, team plans
    if visibility == 'internal' and not (current_user.status == 'active' and _user_plan(current_user) == 'team'):
        raise HTTPException(
            status_code=403,
            detail="Internal groups can only be created by users with an active team plan"
        )

    with compair.Session() as session:
        created_group = models.Group(
            name=name,
            group_image=None,
            category=category,
            description=description,
            visibility=visibility,
            datetime_created=datetime.now(),
        )

        # Reuse the user's Administrator row, creating it on first use
        admin = session.execute(
            select(models.Administrator).filter(
                models.Administrator.user_id == current_user.user_id
            )
        ).scalars().first()
        if admin is None:
            admin = models.Administrator(user_id=current_user.user_id)
            session.add(admin)
            session.commit()

        created_group.admins.append(admin)
        session.add(created_group)
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=created_group.group_id,
            action="create",
            object_id=created_group.group_id,
            object_name=created_group.name,
            object_type="group"
        )

        # Add group to user
        admin.user.groups.append(created_group)
        session.add(admin)
        session.commit()

        # Read the id while the session is still open; once it closes the
        # instance is detached and its expired attributes cannot be reloaded.
        new_group_id = created_group.group_id

    if file is not None:
        await upload_group_image(
            group_id=new_group_id,
            upload_type='group',
            file=file
        )
    return {"group_id": new_group_id}
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
@router.get("/load_documents")
def load_doc(
    group_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    own_documents_only: bool = True,
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, Any] | None:
    """Return one page of documents visible to the current user.

    Args:
        group_id: Restrict to documents attached to this group.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of documents returned.
        filter_type: ``"unpublished"`` (own documents only),
            ``"recently_updated"`` (modified, or given a note, in the last
            7 days) or ``"recently_compaired"`` (feedback in the last 7 days).
        own_documents_only: When false, documents from public groups and from
            the user's groups are included as well.

    Returns:
        ``{"documents": [...], "total_count": int}``. ``total_count`` counts
        distinct matching documents across all pages. Documents are ordered
        by last modification, newest first.
    """
    page = max(page, 1)
    with compair.Session() as session:
        now = datetime.now(timezone.utc)
        week_ago = now - timedelta(days=7)

        # Get user and their group memberships
        user_group_ids = {g.group_id for g in current_user.groups}

        # Relationship filters use EXISTS (``.any()``) rather than JOINs so a
        # document with several groups/notes/feedbacks appears once: this
        # keeps ``count()`` and LIMIT/OFFSET paging consistent.
        q = session.query(models.Document)

        if own_documents_only:
            q = q.filter(models.Document.user_id == current_user.user_id)
            if group_id is not None:
                q = q.filter(models.Document.groups.any(models.Group.group_id == group_id))
        else:
            group_visible = (models.Group.visibility == "public") | models.Group.group_id.in_(user_group_ids)
            if group_id:
                group_visible = (models.Group.group_id == group_id) & group_visible
            q = q.filter(models.Document.groups.any(group_visible))

        # --- Filter logic: publishing ---
        if filter_type == "unpublished" and own_documents_only:
            q = q.filter(models.Document.is_published == False)  # noqa: E712 (SQL expression)
        else:
            q = q.filter(
                or_(
                    models.Document.is_published == True,  # noqa: E712 (SQL expression)
                    models.Document.user_id == current_user.user_id
                )
            )

        # --- Filter logic: other ---
        if filter_type == "recently_updated":
            # Documents updated OR with a note in the past week
            q = q.filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Document.notes.any(models.Note.datetime_created >= week_ago)
                )
            )
        elif filter_type == "recently_compaired":
            # Documents with feedback in the past week
            q = q.filter(
                models.Document.chunks.any(
                    models.Chunk.feedbacks.any(models.Feedback.timestamp >= week_ago)
                )
            )

        total_count = q.count()

        # Paging
        offset = (page - 1) * page_size
        documents = (
            q.options(
                joinedload(models.Document.groups),
                joinedload(models.Document.user).joinedload(models.User.groups)
            )
            .order_by(
                models.Document.datetime_modified.desc(),
                models.Document.datetime_created.desc()
            )
            .offset(offset)
            .limit(page_size)
            .all()
        )
        return {
            "documents": [schema.Document.model_validate(d) for d in documents],
            "total_count": total_count
        }
|
|
1080
|
+
|
|
1081
|
+
|
|
1082
|
+
@router.get("/load_group_users")
def load_group_users(
    group_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return one page of a group's members.

    Members of the group may list any group. Everyone else may list only
    public groups.

    Args:
        group_id: Group whose members are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of users returned.
        filter_type: ``"recently_active"`` orders by ``status_change_date``;
            ``"recently_joined"`` orders by the user's registration date;
            anything else orders by name.

    Returns:
        ``{"users": [...], "total_count": int}``.

    Raises:
        HTTPException: 404 if the group does not exist; 403 if the group is
            not public and the current user is not a member.
    """
    page = max(page, 1)
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        # Compare ids rather than ORM instances: ``current_user`` may have been
        # loaded by another session, in which case an identity-based
        # ``in group.users`` test would reject genuine members.
        member_ids = {member.user_id for member in group.users}
        if current_user.user_id not in member_ids and group.visibility != 'public':
            raise HTTPException(status_code=403, detail="User does not belong to the group")

        # Retrieve all users associated with the group
        users_query = session.query(models.User).join(models.User.groups).filter(models.Group.group_id == group_id)

        # --- Filter logic ---
        if filter_type == "recently_active":
            users_query = users_query.order_by(models.User.status_change_date.desc())
        elif filter_type == "recently_joined":
            users_query = users_query.order_by(models.User.datetime_registered.desc())
        else:
            users_query = users_query.order_by(models.User.name.asc())

        total_count = users_query.count()
        offset = (page - 1) * page_size
        users = users_query.offset(offset).limit(page_size).all()

        return {
            "users": [
                {
                    "user_id": u.user_id,
                    "username": u.username,
                    "name": u.name,
                    "datetime_registered": u.datetime_registered,
                    "status": u.status,
                    "profile_image": u.profile_image,
                    "role": u.role,
                }
                for u in users
            ],
            "total_count": total_count
        }
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
@router.get("/load_document")
def load_doc(
    title: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Find one of the current user's own documents by title.

    The title is compared with the column's ``match`` operator. The
    document's groups and its user (with that user's groups) are loaded
    eagerly via ``joinedload``.

    Returns:
        The first matching document, or ``None``.
    """
    eager_loads = (
        joinedload(models.Document.groups),
        joinedload(models.Document.user).joinedload(models.User.groups),
    )
    stmt = (
        select(models.Document)
        .filter(models.Document.user_id == current_user.user_id)
        .filter(models.Document.title.match(title))
        .options(*eager_loads)
    )
    with compair.Session() as session:
        hit = session.execute(stmt).unique().fetchone()
        return hit[0] if hit is not None else None
|
|
1154
|
+
|
|
1155
|
+
|
|
1156
|
+
@router.get("/load_document_by_id")
def load_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Load a document by id, enforcing read access.

    Access is granted to the document's owner (``user_id``), its author
    (``author_id``), and any user sharing at least one group with it.

    Returns:
        The document, or ``None`` if no document has that id.

    Raises:
        HTTPException: 403 when the current user may not view the document.
    """
    with compair.Session() as session:
        q = select(models.Document).filter(
            models.Document.document_id == document_id
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups)
        )
        row = session.execute(q).unique().fetchone()
        if row is None:
            return None
        doc = row[0]

        # The owner (user_id) may differ from the author (author_id) after
        # update_doc changes the author; both keep access to the document.
        is_owner_or_author = current_user.user_id in (doc.user_id, doc.author_id)
        doc_group_ids = {g.group_id for g in doc.groups}
        user_group_ids = {g.group_id for g in current_user.groups}
        if not is_owner_or_author and doc_group_ids.isdisjoint(user_group_ids):
            raise HTTPException(status_code=403, detail="Not authorized to view this document")
        return doc
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
@router.post("/update_doc")
def update_doc(
    doc_id: str = Form(...),
    author_id: str = Form(None),
    title: str = Form(None),
    datetime_created: datetime = Form(None),
    group_ids: list[str] = Form(None),
    image_url: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Update selected metadata of a document owned by the current user.

    Only arguments that were supplied (not ``None``) are applied. A supplied
    ``group_ids`` replaces the document's group list entirely; ids that match
    no group are ignored.

    Raises:
        HTTPException: 404 if the document does not exist or is not owned by
            the current user.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        ).first()
        if doc is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to update it.")

        if author_id is not None:
            doc.author_id = author_id
        if title is not None:
            doc.title = title
        if datetime_created is not None:
            doc.datetime_created = datetime_created
        if group_ids is not None:
            # Load the groups in *this* session. Groups loaded by another
            # (closed) session are detached, and re-attaching them can clash
            # with instances already in this session's identity map.
            doc.groups = session.query(models.Group).filter(
                models.Group.group_id.in_(group_ids)
            ).all()
        if image_url is not None:
            doc.image_url = image_url
        session.commit()
|
|
1215
|
+
|
|
1216
|
+
|
|
1217
|
+
@router.post("/create_doc")
def create_doc(
    authorid: str = Form(None),
    document_title: str = Form(None),
    document_type: str = Form(None),
    document_content: str = Form(""),
    groups: str = Form(None),  # TODO: Fix how these get submitted; current comma-separated list string
    is_published: bool = Form(False),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
):
    """Create a document for the current user and attach it to groups.

    Steps:
    1. Suspend the user if their trial has expired.
    2. Enforce the per-plan document limit.
    3. Attach the document to every group id in ``groups``. When ``groups``
       is missing or empty, attach it to the group named after the user.
    4. Optionally publish the document.
    5. Log the creation.

    Args:
        groups: Comma-separated group ids; surrounding whitespace and empty
            entries are ignored.

    Returns:
        ``{"document_id": ...}`` for the new document.

    Raises:
        HTTPException: 403 when the document limit is reached; 400 when none
            of the requested group ids exist; 404 when the user's personal
            group is missing.
    """
    with compair.Session() as session:
        # Reload the user in this session so the status update below is persisted
        current_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()

        # Check if the trial has expired
        trial_expiration = _trial_expiration(current_user)
        if HAS_TRIALS and trial_expiration and current_user.status == "trial" and trial_expiration < datetime.now(timezone.utc):
            current_user.status = "suspended"  # Mark as suspended once the trial expires
            current_user.status_change_date = datetime.now(timezone.utc)
            session.commit()

        # Enforce document limits
        team = _user_team(current_user)
        if IS_CLOUD and HAS_TEAM and team and current_user.status == "active":
            document_limit = team.total_documents_limit  # type: ignore[union-attr]
        elif IS_CLOUD and _user_plan(current_user) == 'individual' and current_user.status == "active":
            document_limit = 100
        else:
            document_limit = int(os.getenv("COMPAIR_CORE_DOCUMENT_LIMIT", "10"))
        document_count = session.query(models.Document).filter(models.Document.user_id == current_user.user_id).count()

        if document_count >= document_limit:
            raise HTTPException(
                status_code=403,
                detail=f"Document limit reached. Individual plan users can have 100, team plans have 100 times the number of users (pooled); other plans can have 10",
            )

        if not authorid:
            authorid = current_user.user_id

        document = models.Document(
            user_id=current_user.user_id,
            author_id=authorid,
            title=document_title,
            content=document_content,
            doc_type=document_type,
            datetime_created=datetime.now(timezone.utc),
            datetime_modified=datetime.now(timezone.utc)
        )

        requested_ids = [gid.strip() for gid in groups.split(',') if gid.strip()] if groups else []
        if requested_ids:
            target_groups = session.execute(
                select(models.Group).filter(models.Group.group_id.in_(requested_ids))
            ).scalars().all()
            if not target_groups:
                raise HTTPException(status_code=400, detail="None of the requested groups exist")
        else:
            personal_group = session.execute(
                select(models.Group).filter(models.Group.name == current_user.username)
            ).scalars().first()
            if personal_group is None:
                raise HTTPException(status_code=404, detail="Personal group not found")
            target_groups = [personal_group]

        document.groups = list(target_groups)
        # The activity log takes a single group id; record the first one.
        primary_group_id = target_groups[0].group_id

        session.add(document)
        session.commit()

        if is_published:
            # Attempt to publish doc, pending user status check
            publish_doc(
                doc_id=document.document_id,
                is_published=True,
                current_user=current_user
            )

        # Log document creation
        log_activity(
            session=session,
            user_id=document.author_id,
            group_id=primary_group_id,
            action="create",
            object_id=document.document_id,
            object_name=document.title,
            object_type="document"
        )
        # Analytics is best-effort; never fail document creation on it.
        try:
            analytics.track("document_created", document.user_id)
        except Exception as exc:
            print(f"analytics track failed: {exc}")
        # Return the document_id for frontend use
        return {"document_id": document.document_id}
|
|
1311
|
+
|
|
1312
|
+
|
|
1313
|
+
@router.get("/publish_doc")
def publish_doc(
    doc_id: str,
    is_published: bool,
    current_user: models.User = Depends(get_current_user)
):
    """Set the published flag on one of the current user's documents.

    Nothing is written when the flag already has the requested value.

    Raises:
        HTTPException: 403 if the user is suspended; 404 if the document does
            not exist or belongs to another user.
    """
    with compair.Session() as session:
        # Suspended users may not change publication state
        if current_user.status == "suspended":
            raise HTTPException(status_code=403, detail="Your subscription has expired. Renew to publish.")

        owned_docs = (
            session.query(models.Document)
            .filter(models.Document.document_id == doc_id)
            .filter(models.Document.user_id == current_user.user_id)
        )
        existing = owned_docs.first()
        if not existing:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to publish it.")
        if existing.is_published != is_published:
            owned_docs.update({'is_published': is_published})
            session.commit()
|
|
1333
|
+
|
|
1334
|
+
|
|
1335
|
+
@router.get("/delete_docs")
def delete_docs(
    doc_ids: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete several of the current user's documents and log each deletion.

    Args:
        doc_ids: Comma-separated document ids; whitespace around ids is
            ignored. Ids that do not exist or belong to another user are
            skipped.
    """
    requested_ids = [d.strip() for d in doc_ids.split(',') if d.strip()]
    if not requested_ids:
        return
    with compair.Session() as session:
        documents = session.query(models.Document).filter(
            models.Document.document_id.in_(requested_ids),
            models.Document.user_id == current_user.user_id
        )

        # Materialize first so logging cannot interfere with a live iteration.
        for document in documents.all():
            # Activity entries are keyed to a group; a document with no groups
            # is deleted without an entry instead of failing with IndexError.
            if not document.groups:
                continue
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=document.groups[0].group_id,
                action="delete",
                object_id=document.document_id,
                object_name=document.title,
                object_type="document"
            )

        documents.delete()
        session.commit()
|
|
1364
|
+
|
|
1365
|
+
|
|
1366
|
+
@router.get("/delete_doc")
def delete_doc(
    doc_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete one of the current user's documents and log the deletion.

    Raises:
        HTTPException: 404 if the document does not exist or belongs to
            another user.
    """
    with compair.Session() as session:
        matching = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        )
        document = matching.first()
        if document is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to delete it.")

        # Capture what the activity log needs before the row disappears.
        # A document without groups is deleted without an activity entry
        # instead of failing with IndexError.
        group_id = document.groups[0].group_id if document.groups else None
        doc_name = document.title

        matching.delete()
        session.commit()

        if group_id is not None:
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=doc_id,
                object_name=doc_name,
                object_type="document"
            )
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
@router.post("/process_doc")
async def process_doc(
    doc_id: str = Form(...),
    doc_text: str = Form(...),
    generate_feedback: bool = Form(True),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
) -> Mapping[str, str | None]:
    """Run document processing for ``doc_id`` and optionally request feedback.

    When ``IS_CLOUD`` is set the task is queued via ``.delay`` and its id is
    returned. Otherwise the task function is called directly and ``task_id``
    is ``None``. Suspended users can still process documents, but feedback
    generation is turned off for them.

    Raises:
        HTTPException: 404 if the document is missing; 403 if the caller is
            not the document's author.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == doc_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")
        # Only allow the author to process/edit
        if doc.author_id != current_user.user_id:
            raise HTTPException(status_code=403, detail="Only the author can edit this document")

        wants_feedback = generate_feedback and current_user.status != "suspended"
        task_kwargs = {
            "user_id": current_user.user_id,
            "doc_id": doc_id,
            "doc_text": doc_text,
            "generate_feedback": wants_feedback,
        }
        if not IS_CLOUD:
            process_document_celery(**task_kwargs)
            task_id = None
        else:
            task_id = process_document_celery.delay(**task_kwargs).id

        if wants_feedback:
            # Analytics is best-effort and must not fail the request
            try:
                analytics.track("feedback_requested", current_user.user_id)
            except Exception as exc:
                print(f"analytics track failed: {exc}")
        return {"task_id": task_id}
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
@router.post("/upload/ocr-file")
async def upload_ocr_file(
    file: UploadFile = File(...),
    document_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Submit an uploaded file to the configured OCR provider.

    Returns:
        ``{"task_id": ...}`` with whatever ``ocr.submit`` returned.

    Raises:
        HTTPException: 501 when OCR is disabled or the provider raises
            ``NotImplementedError``; 500 for any other provider error.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    payload = await file.read()
    upload_name = file.filename or "upload"
    try:
        submitted_id = ocr.submit(
            user_id=current_user.user_id,
            filename=upload_name,
            data=payload,
            document_id=document_id,
        )
    except NotImplementedError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"task_id": submitted_id}
|
|
1466
|
+
|
|
1467
|
+
|
|
1468
|
+
@router.get("/ocr-file-result/{task_id}")
def get_ocr_file_result(
    task_id: str,
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Report the OCR provider's status for ``task_id``.

    A dict returned by ``ocr.status`` is merged into the response after
    ``task_id`` (so its keys win on collision). Any other value is returned
    under ``"status"``. Raises 501 when OCR is disabled or unimplemented, and
    500 for other provider errors.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    try:
        provider_status = ocr.status(task_id)
    except NotImplementedError as err:
        raise HTTPException(status_code=501, detail=str(err)) from err
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err)) from err

    extra = provider_status if isinstance(provider_status, dict) else {"status": provider_status}
    return {"task_id": task_id, **extra}
|
|
1491
|
+
|
|
1492
|
+
|
|
1493
|
+
@router.get("/status/{task_id}")
async def get_process_status(
    task_id: str
):
    """Report the state of a background task by id.

    Returns ``task_id``, ``status`` (the raw task state string) and ``result``:
    - the task's return value when the state is ``SUCCESS``;
    - ``"Task is still processing."`` while the task is queued, received,
      running or being retried;
    - ``"Task failed."`` for any other state (e.g. FAILURE, REVOKED).
    """
    # States in which the task has not finished yet. STARTED/RETRY used to fall
    # through to "Task failed." even though the task was still running.
    in_progress_states = ("PENDING", "RECEIVED", "STARTED", "RETRY")

    task_result = AsyncResult(task_id)
    # Read the state once; each ``.status`` access may hit the result backend.
    state = task_result.status
    if state == "SUCCESS":
        result = task_result.result
    elif state in in_progress_states:
        result = "Task is still processing."
    else:
        result = "Task failed."
    return {"task_id": task_id, "status": state, "result": result}
|
|
1507
|
+
|
|
1508
|
+
|
|
1509
|
+
@router.get("/trial_status")
def get_trial_status(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return the trial status for the authenticated user.

    Response keys:
    - ``status``: the user's status;
    - ``days_left``: whole days (``timedelta.days``) until ``trial_expiration_date``;
    - ``show_banner``: True when 1-7 days remain;
    - ``checkout_url``: a billing checkout URL. It is only attempted when the
      banner is shown and billing is enabled; otherwise it is ``None``.

    Raises 404 when the user has no trial expiration date.
    """
    require_feature(HAS_TRIALS, "Trial management")
    with compair.Session() as session:
        trial_expiration = getattr(current_user, "trial_expiration_date", None)
        if trial_expiration is None:
            raise HTTPException(status_code=404, detail="Trial expiration date not found.")
        if trial_expiration.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC. Subtracting a
            # naive value from an aware "now" would otherwise raise TypeError.
            trial_expiration = trial_expiration.replace(tzinfo=timezone.utc)
        days_left = (trial_expiration - datetime.now(timezone.utc)).days
        show_banner = 0 < days_left <= 7

        checkout_url = None
        if show_banner and settings.billing_enabled:
            require_feature(HAS_BILLING, "Billing integration")
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if db_user:
                if not db_user.stripe_customer_id:
                    # Lazily create the billing customer the first time a banner is shown.
                    try:
                        customer_id = billing.ensure_customer(
                            user_email=db_user.username,
                            user_id=db_user.user_id,
                        )
                    except NotImplementedError:
                        customer_id = None
                    else:
                        db_user.stripe_customer_id = customer_id
                        session.commit()
                customer_id = db_user.stripe_customer_id
                if customer_id:
                    try:
                        session_info = billing.create_checkout_session(
                            customer_id=customer_id,
                            price_id=plan_to_id.get("individual_monthly"),
                            qty=1,
                            success_url=settings.stripe_success_url,
                            cancel_url=settings.stripe_cancel_url,
                            metadata={"plan": "individual_monthly", "user_id": db_user.user_id},
                        )
                        # Providers may return an object with ``.url`` or a mapping.
                        checkout_url = getattr(session_info, "url", None)
                        if checkout_url is None and hasattr(session_info, "get"):
                            checkout_url = session_info.get("url")
                    except NotImplementedError:
                        checkout_url = None
                    except Exception as exc:
                        # Best-effort: a checkout failure must not break the status call.
                        print(f"Error generating Checkout URL: {exc}")

    return {
        "status": current_user.status,
        "days_left": days_left,
        "show_banner": show_banner,
        "checkout_url": checkout_url,
    }
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
@router.get("/load_chunks")
def load_chunks(
    document_id: str,
) -> list[schema.Chunk]:
    """Return all chunks belonging to ``document_id``.

    The rows are materialized with ``.all()`` while the session is still open,
    so the endpoint returns a real list matching the annotated return type.
    Returning the bare ``Query`` would defer execution until after the session
    context had exited.
    """
    with compair.Session() as session:
        return (
            session.query(models.Chunk)
            .filter(models.Chunk.document_id == document_id)
            .all()
        )
|
|
1577
|
+
|
|
1578
|
+
|
|
1579
|
+
@router.get("/load_feedback")
def load_feedback(
    chunk_id: str,
) -> schema.Feedback | None:
    """Return one feedback entry whose source chunk is ``chunk_id``, or None.

    The returned mapping includes the user's rating (``user_feedback``) and the
    ``is_hidden`` flag. No ordering is applied, so if several rows match, which
    one is returned is up to the database.
    """
    with compair.Session() as session:
        # A single LIMIT 1 query replaces the former "fetch all, then index [0]"
        # pair, which hit the database twice.
        f = (
            session.query(models.Feedback)
            .filter(models.Feedback.source_chunk_id == chunk_id)
            .first()
        )
        if f is None:
            return None
        return {
            "feedback_id": f.feedback_id,
            "source_chunk_id": f.source_chunk_id,
            "feedback": f.feedback,
            "user_feedback": f.user_feedback,
            "is_hidden": f.is_hidden,
        }
|
|
1598
|
+
|
|
1599
|
+
|
|
1600
|
+
@router.post("/feedback/{feedback_id}/hide")
def hide_feedback(
    feedback_id: str,
    is_hidden: bool = Form(False)
):
    """Set the ``is_hidden`` flag of one feedback entry.

    Raises 404 when no feedback has ``feedback_id``.
    NOTE(review): this route declares no user dependency, so no ownership check
    happens here.
    """
    with compair.Session() as session:
        target = (
            session.query(models.Feedback)
            .filter(models.Feedback.feedback_id == feedback_id)
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Feedback not found")
        target.is_hidden = is_hidden
        session.commit()
        verb = 'hidden' if is_hidden else 'unhidden'
        return {"message": f"Feedback {verb} successfully", "feedback_id": feedback_id}
|
|
1612
|
+
|
|
1613
|
+
|
|
1614
|
+
|
|
1615
|
+
@router.post("/feedback/{feedback_id}/rate")
def rate_feedback(
    feedback_id: str,
    # "positive", "negative", or null to clear an existing rating. The type must
    # be Optional for the null that the validation below permits to get through.
    user_feedback: Optional[str] = Body(..., embed=True)
):
    """Store the user's rating for a feedback entry.

    Raises 400 for any label other than "positive", "negative" or null, and
    404 when no feedback has ``feedback_id``.
    """
    if user_feedback not in ("positive", "negative", None):
        raise HTTPException(status_code=400, detail="Invalid feedback label")
    with compair.Session() as session:
        feedback = session.query(models.Feedback).filter(models.Feedback.feedback_id == feedback_id).first()
        if not feedback:
            raise HTTPException(status_code=404, detail="Feedback not found")
        feedback.user_feedback = user_feedback
        session.commit()
        return {"message": "Feedback rated successfully", "feedback_id": feedback_id, "user_feedback": user_feedback}
|
|
1629
|
+
|
|
1630
|
+
|
|
1631
|
+
@router.get("/documents/{document_id}/feedback")
def list_document_feedback(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """List every feedback entry attached to a document's chunks, newest first.

    Access is granted when the caller is the document's ``user_id``, shares at
    least one group with the document, or the document is published. Raises 404
    for an unknown document and 403 when access is denied.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        my_groups = {grp.group_id for grp in current_user.groups}
        doc_groups = {grp.group_id for grp in doc.groups}
        allowed = (
            doc.user_id == current_user.user_id
            or bool(doc_groups & my_groups)
            or doc.is_published
        )
        if not allowed:
            raise HTTPException(status_code=403, detail="Not authorized to access this document")

        # Feedback rows reference chunks, so reach the document through Chunk.
        rows = (
            session.query(models.Feedback)
            .join(models.Chunk, models.Feedback.source_chunk_id == models.Chunk.chunk_id)
            .filter(models.Chunk.document_id == document_id)
            .order_by(models.Feedback.timestamp.desc())
            .all()
        )
        entries = []
        for row in rows:
            entries.append({
                "feedback_id": row.feedback_id,
                "chunk_id": row.source_chunk_id,
                "feedback": row.feedback,
                "user_feedback": row.user_feedback,
                "timestamp": row.timestamp,
            })
        return {
            "document_id": document_id,
            "count": len(rows),
            "feedback": entries,
        }
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
@router.get("/load_references")
def load_references(
    chunk_id: str,
) -> list[schema.Reference]:
    """Return the references whose source chunk is ``chunk_id``.

    Each reference embeds its referenced document and that document's user.
    The ``groups`` lists on both are placeholder entries (empty id and name),
    not the real memberships.
    """
    def _placeholder_groups() -> list[dict]:
        # Fresh stand-in group entry per use; real group memberships are not loaded here.
        return [{"group_id": "", "datetime_created": datetime.now(), "name": ""}]

    with compair.Session() as session:
        references = session.query(models.Reference).filter(
            models.Reference.source_chunk_id == chunk_id
        ).all()
        returned_references: list[schema.Reference] = []
        for r in references:
            ref_doc = r.document
            owner = ref_doc.user
            returned_references.append(
                schema.Reference(
                    reference_id=r.reference_id,
                    source_chunk_id=r.source_chunk_id,
                    reference_document_id=r.reference_document_id,
                    document=schema.Document(
                        document_id=ref_doc.document_id,
                        user_id=ref_doc.user_id,
                        author_id=ref_doc.author_id,
                        title=ref_doc.title,
                        content=ref_doc.content,
                        doc_type=ref_doc.doc_type,
                        datetime_created=ref_doc.datetime_created,
                        datetime_modified=ref_doc.datetime_modified,
                        is_published=ref_doc.is_published,
                        groups=_placeholder_groups(),
                        user=schema.User(
                            user_id=owner.user_id,
                            username=owner.username,
                            name=owner.name,
                            groups=_placeholder_groups(),
                            datetime_registered=owner.datetime_registered,
                            status=owner.status,
                        ),
                    ),
                    document_author=owner.username,
                )
            )
        # Deliberately not printed: the list carries full document contents.
        return returned_references
|
|
1707
|
+
|
|
1708
|
+
@router.get("/verify-email")
def verify_email(token: str):
    """Verify the account holding ``token`` and start its trial.

    On success:
    - the user's status becomes ``'trial'`` and the status-change time is recorded;
    - a 30-day trial expiration is set when ``HAS_TRIALS``;
    - the verification token is cleared;
    - every ``"pending"`` group invitation addressed to the user's email is
      marked ``"sent"`` (given a token and a 7-day expiry if it has no token)
      and emailed.

    Raises 400 when no user holds the token or the token has expired.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.verification_token == token).first()
        # Check for a missing user before reading any of its attributes so an
        # unknown token yields a 400 rather than an AttributeError (500).
        if not user:
            raise HTTPException(status_code=400, detail="Invalid or expired token")
        now = datetime.now(timezone.utc)
        if user.token_expiration < now:
            raise HTTPException(status_code=400, detail="Token has expired")
        user.status = 'trial'
        user.status_change_date = now
        if HAS_TRIALS:
            user.trial_expiration_date = now + timedelta(days=30)
        user.verification_token = None  # Invalidate the token
        session.commit()

        pending_invitations = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.email == user.username,
            models.GroupInvitation.status == "pending"
        ).all()
        for invitation in pending_invitations:
            # Mark as "sent"
            invitation.status = "sent"
            # Generate invitation token if not present
            if not invitation.token:
                invitation.token = secrets.token_urlsafe(32).lower()
                invitation.datetime_expiration = datetime.now(timezone.utc) + timedelta(days=7)
            session.commit()

            # Send the actual group invitation email
            group = invitation.group
            inviter = invitation.inviter
            invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={invitation.token}&user_id={user.user_id}"
            emailer.connect()
            emailer.send(
                subject="You’re Invited to Join a Group on Compair",
                sender=EMAIL_USER,
                receivers=[user.username],
                html=GROUP_INVITATION_TEMPLATE.replace(
                    "{{inviter_name}}", inviter.name
                ).replace(
                    "{{group_name}}", group.name
                ).replace(
                    "{{invitation_link}}", invitation_link
                )
            )

    return {"message": "Email verified successfully. Your free trial has started!"}
|
|
1759
|
+
|
|
1760
|
+
def is_valid_email(email):
    """Return a match object if ``email`` looks like ``local@domain.tld``, else None.

    The whole string must match (``re.fullmatch``). This rejects trailing text
    such as a second ``@`` segment, and whitespace is not allowed anywhere.
    Non-string input returns None instead of raising TypeError.
    """
    if not isinstance(email, str):
        return None
    return re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email)
|
|
1762
|
+
|
|
1763
|
+
@router.post("/sign-up")
def sign_up(
    request: schema.SignUpRequest,
    analytics: Analytics = Depends(get_analytics),
) -> dict:
    """Register a new account and email it a verification link.

    ``request.username`` is the user's email address. Raises 400 when it is not
    a valid email. The ``analytics`` dependency is accepted but not used here.
    """
    if not is_valid_email(request.username):
        raise HTTPException(status_code=400, detail="Invalid email address")
    with compair.Session() as session:
        # Delegate account creation (including the verification token) to create_user.
        user = create_user(
            username=request.username,
            name=request.name,
            password=request.password,
            groups=request.groups,
            session=session,
            referral_code=request.referral_code,
        )
        verification_link = f"http://{WEB_URL}/verify-email?token={user.verification_token}"
        emailer.connect()
        emailer.send(
            subject="Verify your email address",
            sender=EMAIL_USER,
            receivers=[user.username],
            html=ACCOUNT_VERIFY_TEMPLATE.replace("{{verification_link}}", verification_link)
        )
    return {"message": "Sign-up successful. Please check your email for verification."}
|
|
1798
|
+
|
|
1799
|
+
@router.post("/forgot-password")
def forgot_password(request: schema.ForgotPasswordRequest) -> dict:
    """Email a password-reset link if an account exists for ``request.email``.

    The same message is returned whether or not the account exists, so callers
    cannot probe for registered emails. The reset token is lower-cased before
    it is stored and embedded in the link. It is never printed, because it is
    a credential.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.username == request.email).first()
        if not user:
            return {"message": "If the email exists, a reset link will be sent."}

        token, expiration = generate_verification_token()
        # NOTE(review): if the generator emits mixed-case tokens, lower-casing
        # shrinks their alphabet. It is kept so that the stored token and the
        # emailed link agree; confirm whether it is still required.
        token = token.lower()
        user.reset_token = token
        user.token_expiration = expiration
        session.commit()

        reset_link = f"http://{WEB_URL}/reset-password?token={token}"
        emailer.connect()
        emailer.send(
            subject="Password Reset Request",
            sender=EMAIL_USER,
            receivers=[request.email],
            html=PASSWORD_RESET_TEMPLATE.replace("{{reset_link}}", reset_link)
        )
    return {"message": "If the email exists, a reset link will be sent."}
|
|
1830
|
+
|
|
1831
|
+
@router.post("/reset-password")
def reset_password(request: schema.ResetPasswordRequest) -> dict:
    """Set a new password for the user holding a valid reset token.

    Raises 400 when no user holds ``request.token``, when no expiration is
    recorded for it, or when it has expired. On success the token and its
    expiration are cleared so the link cannot be reused. Tokens are not
    printed, because they are credentials.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.reset_token == request.token).first()
        # A missing expiration is treated as invalid rather than raising a
        # TypeError on the comparison.
        if (
            not user
            or user.token_expiration is None
            or user.token_expiration < datetime.now(timezone.utc)
        ):
            raise HTTPException(status_code=400, detail="Invalid or expired token")

        user.set_password(request.new_password)
        user.reset_token = None  # Invalidate the token
        user.token_expiration = None
        session.commit()

    return {"message": "Password has been reset successfully"}
|
|
1851
|
+
|
|
1852
|
+
@router.get("/admin/groups")
def get_admin_groups(
    current_user: models.User = Depends(get_current_user),
):
    """List the groups administered by the authenticated user.

    Returns an empty list when the caller has no Administrator record;
    otherwise returns one summary dict per group in that record's ``groups``.
    """
    with compair.Session() as session:
        admin_record = (
            session.query(models.Administrator)
            .filter(models.Administrator.user_id == current_user.user_id)
            .first()
        )
        if not admin_record:
            return []  # Not an admin of any groups

        summaries = []
        for group in admin_record.groups:
            summaries.append({
                "group_id": group.group_id,
                "name": group.name,
                "visibility": group.visibility,
                "category": group.category,
                "description": group.description,
                "group_image": group.group_image,
                "datetime_created": group.datetime_created,
            })
        return summaries
|
|
1876
|
+
|
|
1877
|
+
@router.get("/admin/join_requests")
def get_join_requests(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """List the join requests stored for ``group_id``.

    Only requests in groups the caller administers are returned. A non-admin
    gets an empty list rather than an error. Approve/reject delete rows, so the
    rows listed are the undecided ones.
    """
    caller_is_admin = models.JoinRequest.group.has(
        models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
    )
    with compair.Session() as session:
        pending = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.group_id == group_id)
            .filter(caller_is_admin)
            .all()
        )
        return [
            {
                "request_id": jr.request_id,
                "user_name": jr.user.name,
                "datetime_requested": jr.datetime_requested,
            }
            for jr in pending
        ]
|
|
1893
|
+
|
|
1894
|
+
@router.post("/admin/approve_request")
def approve_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Approve a join request for a group the caller administers.

    The steps are:
    - add the requesting user to the group, unless they are already a member;
    - delete the request;
    - log a "join" activity.

    Raises 404 if the request does not exist or the caller is not an admin of
    its group.
    """
    with compair.Session() as session:
        request = session.query(models.JoinRequest).filter(
            models.JoinRequest.request_id == request_id
        ).filter(
            models.JoinRequest.group.has(models.Group.admins.any(models.Administrator.user_id == current_user.user_id))
        ).first()
        if not request:
            raise HTTPException(status_code=404, detail="Request not found")

        group: models.Group = request.group
        user: models.User = request.user
        # The user may have joined by other means (e.g. an invitation) since the
        # request was made. Re-appending could create a duplicate membership row.
        if user not in group.users:
            group.users.append(user)
        session.delete(request)
        session.commit()

        log_activity(
            session=session,
            user_id=user.user_id,
            group_id=group.group_id,
            action="join",
            object_id=group.group_id,
            object_name=group.name,
            object_type="group"
        )

    return {"message": "Request approved successfully"}
|
|
1926
|
+
|
|
1927
|
+
@router.post("/admin/reject_request")
def reject_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Delete a join request in a group the caller administers.

    Raises 404 when the request does not exist or the caller is not an admin
    of its group.
    """
    caller_is_admin = models.JoinRequest.group.has(
        models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
    )
    with compair.Session() as session:
        join_request = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.request_id == request_id)
            .filter(caller_is_admin)
            .first()
        )
        if not join_request:
            raise HTTPException(status_code=404, detail="Request not found")

        session.delete(join_request)
        session.commit()

    return {"message": "Request rejected successfully"}
|
|
1945
|
+
|
|
1946
|
+
@router.post("/admin/update_group")
def update_group(
    group_id: str,
    name: Optional[str] = Form(None),
    visibility: Optional[str] = Form(None),
    category: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Update settings of a group the caller administers.

    Only truthy form values are applied; omitted or empty fields leave the
    current setting unchanged. Raises 404 if the group does not exist or the
    caller is not one of its admins. User-supplied values are not echoed to
    stdout.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        if name:
            group.name = name
        if visibility:
            group.visibility = visibility
        if category:
            group.category = category
        if description:
            group.description = description
        session.commit()

    return {"message": "Group updated successfully"}
|
|
1982
|
+
|
|
1983
|
+
@router.post("/admin/invite_member")
def invite_member_to_group(
    request: schema.InviteMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """
    Invite an existing Compair user (by username) to join a group.

    The authenticated caller must administer the group. If the target user
    exists, is not already a member and has no invitation for the group, the
    endpoint:
    - stores a GroupInvitation with a 7-day expiry;
    - emails the invitation link;
    - marks the invitation ``'sent'``.

    Raises:
        404 if the group is not found among the caller's groups, or if no user
        has ``request.username``.
        403 if no matching Administrator row exists.
        400 if the user is already a member or already invited.
    """
    # Authorize as the authenticated caller. ``request.admin_id`` is still
    # accepted by the schema but not trusted, since any client could set it to
    # another admin's id. Plain assignments: a trailing comma would turn these
    # values into 1-tuples and break every comparison below.
    admin_id = current_user.user_id
    group_id = request.group_id
    username = request.username
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        user = session.query(models.User).filter(models.User.username == username).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        # Check if already a member or already invited
        if user in group.users:
            raise HTTPException(status_code=400, detail="User is already a member")
        existing_invite = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username
        ).first()
        if existing_invite:
            raise HTTPException(status_code=400, detail="User already invited")

        token = secrets.token_urlsafe(32).lower()
        invitation = models.GroupInvitation(
            group_id=group_id,
            inviter_id=admin_id,
            token=token,
            email=user.username,
            # Timezone-aware, so the accept path can compare it with an aware "now".
            datetime_expiration=datetime.now(timezone.utc) + timedelta(days=7),
            status='pending'
        )
        session.add(invitation)
        session.commit()

        invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={token}&user_id={user.user_id}"
        html = GROUP_INVITATION_TEMPLATE
        for key, value in (
            ("inviter_name", admin.user.name),
            ("group_name", group.name),
            ("invitation_link", invitation_link),
        ):
            # verify_email fills this same template using "{{key}}" (no spaces),
            # so replace both spellings of each placeholder.
            html = html.replace("{{ " + key + " }}", value).replace("{{" + key + "}}", value)
        emailer.connect()
        emailer.send(
            subject="You’re Invited to Join a Group on Compair",
            sender=EMAIL_USER,
            receivers=[user.username],
            html=html
        )
        invitation.status = 'sent'
        session.commit()
    return {"message": "Invitation sent successfully"}
|
|
2058
|
+
|
|
2059
|
+
@router.post("/admin/invite_new_user")
def invite_new_user_to_group(
    request: schema.InviteToGroupRequest,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Generate an invitation link to Compair and send it via email, logging a Group request to send on signup.

    The authenticated caller must administer ``request.group_id``. A
    ``"pending"`` GroupInvitation for ``request.email`` is stored; verify_email
    later turns it into a sent group invitation once the invitee verifies.
    The invitee receives the admin's referral link.

    Raises 404 if the group is not found among the caller's groups, and 403 if
    no matching Administrator row exists.
    """
    # Authorize as the authenticated caller; ``request.admin_id`` is client
    # supplied and therefore not trusted.
    admin_id = current_user.user_id
    group_id = request.group_id
    email = request.email
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        referral_link = generate_referral_link(admin.user.referral_code)

        emailer.connect()
        # Track pending group invitation for this email if group_id is provided
        if group_id:
            token = secrets.token_urlsafe(32).lower()
            invitation = models.GroupInvitation(
                group_id=group_id,
                inviter_id=admin.user.user_id,
                token=token,
                email=email,
                # Timezone-aware, so the accept path can compare it with an aware "now".
                datetime_expiration=datetime.now(timezone.utc) + timedelta(days=7),
                status="pending"
            )
            session.add(invitation)
            session.commit()
        emailer.send(
            subject="You're Invited to Compair!",
            sender=EMAIL_USER,
            receivers=[email],
            html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                "{{inviter_name}}", admin.user.name
            ).replace(
                "{{referral_link}}", referral_link,
            )
        )

    return {"message": "Invitation to join Compair successful."}
|
|
2114
|
+
|
|
2115
|
+
@router.post("/admin/remove_member")
def remove_member(
    request: schema.RemoveMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """Remove ``request.user_id`` from ``request.group_id``.

    Only an admin of the group may do this. The removal is logged as a
    "remove_member" activity attributed to the caller.
    Raises 404 if the group is missing or the user is not a member, and 403
    if the caller is not an admin of the group.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == request.group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        caller_id = current_user.user_id
        if all(adm.user_id != caller_id for adm in group.admins):
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        member = session.query(models.User).filter(models.User.user_id == request.user_id).first()
        if not member or member not in group.users:
            raise HTTPException(status_code=404, detail="User is not a member of this group")

        group.users.remove(member)
        session.commit()

        log_activity(
            session=session,
            user_id=caller_id,
            group_id=group.group_id,
            action="remove_member",
            object_id=member.user_id,
            object_name=member.name,
            object_type="user"
        )

        return {"message": f"User {member.name} has been removed from the group"}
|
|
2152
|
+
|
|
2153
|
+
@router.get("/accept_group_invitation")
def accept_group_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a group invitation using a token.

    Adds the caller to the invitation's group (unless already a member) and
    marks the invitation ``"accepted"``.
    Raises 404 for an unknown token. Raises 400 if the invitation was already
    accepted, or if it has expired or has no expiry recorded; in that case it
    is also marked ``"expired"``.
    """
    with compair.Session() as session:
        invitation = session.query(models.GroupInvitation).filter(models.GroupInvitation.token == token).first()
        if not invitation:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation")
        # An accepted invitation must not be replayed.
        if invitation.status == "accepted":
            raise HTTPException(status_code=400, detail="Invitation has already been accepted")

        expires_at = invitation.datetime_expiration
        if expires_at is not None and expires_at.tzinfo is None:
            # Rows written with datetime.utcnow() may come back naive; treat them
            # as UTC instead of raising TypeError against an aware "now".
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at is None or expires_at < datetime.now(timezone.utc):
            invitation.status = "expired"
            session.commit()
            raise HTTPException(status_code=400, detail="Invitation has expired")

        group = invitation.group
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()

        # Avoid a duplicate membership if the user already belongs to the group.
        if user not in group.users:
            group.users.append(user)
        invitation.status = "accepted"
        session.commit()

    return {"message": "You have successfully joined the group"}
|
|
2178
|
+
|
|
2179
|
+
@router.get("/accept_team_invitation")
def accept_team_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a team invitation using a token.

    ``token`` is matched against ``TeamInvitation.invitation_id`` and only
    invitations in status ``"sent"`` are accepted. On success the user's
    ``team_id`` (when the model has one) is set to the invitation's team and the
    invitation is marked ``"accepted"``.

    Raises:
        HTTPException(404): no sent invitation matches ``token``, or the current
            user's record cannot be found.
    """
    require_feature(HAS_TEAM, "Team collaboration")
    with compair.Session() as session:
        invitation = session.query(models.TeamInvitation).filter(
            models.TeamInvitation.invitation_id == token,
            models.TeamInvitation.status == "sent",
        ).first()
        if not invitation:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation.")

        # ``current_user`` comes from the auth dependency, outside this session, so
        # changes made on it may not be flushed by this ``session.commit()``.
        # Re-load the user here (as the other endpoints do) and mutate that copy.
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found.")
        if hasattr(user, "team_id"):
            user.team_id = invitation.team_id
        invitation.status = "accepted"
        session.commit()

        return {"message": "You have successfully joined the team!"}
|
|
2200
|
+
|
|
2201
|
+
def get_feedback_tooltip(
    feedback_id: str,
) -> str:
    """Return tooltip text for the feedback identified by ``feedback_id``.

    The ``feedback`` text of every matching ``Feedback`` row is joined with
    ``" | "``; an empty string is returned when nothing matches.
    """
    with compair.Session() as session:
        matching_rows = (
            session.query(models.Feedback)
            .filter(models.Feedback.feedback_id == feedback_id)
            .all()
        )
        texts = [row.feedback for row in matching_rows]
        return " | ".join(texts)
|
|
2208
|
+
|
|
2209
|
+
@router.get("/get_activity_feed")
def get_activity_feed(
    user_id: str,
    page: int = 1,
    page_size: int = 10,
    include_own_activities: bool = True,
    current_user: models.User = Depends(get_current_user)
):
    """Retrieve a page of recent activities for the current user's groups.

    Args:
        user_id: Accepted for API compatibility but not used; the feed is always
            built from ``current_user``'s groups.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Activities per page; negative values are treated as 0.
        include_own_activities: When False, activities performed by the current
            user are excluded.
        current_user: The authenticated user.

    Returns:
        ``{"activities": [...], "total_count": int}``. Activities are ordered by
        ``timestamp`` descending. ``tooltip`` carries the feedback text for
        ``"provided feedback"`` activities and ``None`` otherwise.
    """
    require_feature(HAS_ACTIVITY, "Activity feed")
    # Clamp paging inputs so a bad query string cannot produce a negative OFFSET/LIMIT.
    page = max(page, 1)
    page_size = max(page_size, 0)

    with compair.Session() as session:
        # Only activities in groups the current user belongs to are visible.
        group_ids = [g.group_id for g in current_user.groups]

        q = (
            session.query(models.Activity)
            .filter(models.Activity.group_id.in_(group_ids))
            .order_by(models.Activity.timestamp.desc())
        )
        if not include_own_activities:
            q = q.filter(models.Activity.user_id != current_user.user_id)

        total_count = q.count()
        offset = (page - 1) * page_size
        activities = q.offset(offset).limit(page_size).all()

        return {
            "activities": [
                {
                    # Guard against a missing related user instead of raising AttributeError.
                    "user": activity.user.name if activity.user is not None else None,
                    "user_id": activity.user_id,
                    "group_id": activity.group_id,
                    "action": activity.action,
                    "object": f"{activity.object_type} {activity.object_name}",
                    "object_type": activity.object_type,
                    "object_name": activity.object_name,
                    "object_id": activity.object_id,
                    "timestamp": activity.timestamp,
                    "tooltip": get_feedback_tooltip(activity.object_id) if activity.action == "provided feedback" else None
                }
                for activity in activities
            ],
            "total_count": total_count
        }
|
|
2255
|
+
|
|
2256
|
+
@router.delete("/delete_group")
def delete_group(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete a group. Allowed only if the current user is an admin of the group.

    This removes the group and its associations. Documents remain but lose their link to this group.

    After the delete is committed, a ``"delete"`` activity is logged on a
    best-effort basis: a logging failure is rolled back and printed, never
    turned into an error response.

    Raises:
        HTTPException(404): the group does not exist.
        HTTPException(403): the current user is not among ``group.admins``.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        # Check admin rights
        is_admin = any(a.user_id == current_user.user_id for a in group.admins)
        if not is_admin:
            raise HTTPException(status_code=403, detail="Not authorized to delete this group")
        # Read the name before deleting so it is still available for the activity log.
        name = group.name
        session.delete(group)
        session.commit()
        # Log activity (best-effort: the delete has already been committed).
        try:
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=group_id,
                object_name=name,
                object_type="group",
            )
        except Exception as exc:
            # Leave the session clean and surface the failure instead of hiding it.
            session.rollback()
            print(f"delete_group: failed to log activity for group {group_id}: {exc}")
        return {"message": "Group deleted", "group_id": group_id}
|
|
2290
|
+
|
|
2291
|
+
# Maps the plan names accepted by /create-checkout-session to billing price IDs.
# Paid invoices are mapped back to a plan name by searching these values.
plan_to_id = {
    "individual_monthly": "price_1Ry34UEPPPghXLIJ4OLQNEna",
    "individual_annual": "price_1Ry34NEPPPghXLIJeftgEYOD",
    "team_monthly": "price_1Ry34HEPPPghXLIJoEyGRm7h",
    "team_annual": "price_1Ry34REPPPghXLIJXhLhwpBW",
}
|
|
2299
|
+
|
|
2300
|
+
@router.post("/create-checkout-session")
def create_checkout_session(
    checkout_map: Mapping,
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Create a checkout session through the configured billing provider.

    Keys read from ``checkout_map``:
        plan: a key of ``plan_to_id`` (default ``"individual_monthly"``).
        team_emails: list of emails or a comma-separated string (team plans);
            entries are trimmed and blank entries dropped.
        team_name: optional team name (default ``"<user name>'s Team"``).

    Ensures the user has a billing customer id. For team plans it creates a
    ``Team`` and ``"pending"`` ``TeamInvitation`` rows before checkout; the
    checkout quantity is the number of team emails (1 for individual plans).
    The provider's session id and the chosen plan are stored on the user.

    Returns:
        ``{"url": <checkout URL>}``.

    Raises:
        HTTPException(501): billing disabled, or the provider lacks the operation.
        HTTPException(400): unsupported plan, or too few team emails.
        HTTPException(404): the current user no longer exists.
        HTTPException(500): any other failure while creating the session.
    """
    # Minimum number of team emails a team plan must include.
    # NOTE(review): the enforced minimum (3) previously disagreed with the error
    # text ("at least 4"); confirm which is the intended business rule.
    min_team_emails = 3

    require_cloud("Billing")

    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    plan = checkout_map.get("plan", "individual_monthly")
    team_emails = checkout_map.get("team_emails", []) or []
    if isinstance(team_emails, str):
        team_emails = team_emails.split(',')
    # Normalise both the string and list forms: trim whitespace and drop blanks so
    # stray entries neither inflate the seat quantity nor create bogus invitations.
    team_emails = [email.strip() for email in team_emails if isinstance(email, str) and email.strip()]
    if plan not in plan_to_id:
        raise HTTPException(status_code=400, detail="Unsupported plan selected.")

    if plan.startswith("team") and len(team_emails) < min_team_emails:
        raise HTTPException(status_code=400, detail=f"Team plans require at least {min_team_emails} emails.")

    price_id = plan_to_id[plan]

    quantity = 1 if plan.startswith("individual") else len(team_emails)

    try:
        with compair.Session() as session:
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if not db_user:
                raise HTTPException(status_code=404, detail="User not found.")

            if not db_user.stripe_customer_id:
                log_event("stripe_customer_create_attempt", user_id=db_user.user_id)
                try:
                    customer_id = billing.ensure_customer(
                        user_email=db_user.username,
                        user_id=db_user.user_id,
                    )
                except NotImplementedError as exc:
                    raise HTTPException(status_code=501, detail=str(exc)) from exc
                db_user.stripe_customer_id = customer_id
                session.commit()
                log_event("stripe_customer_created", user_id=db_user.user_id, customer_id=customer_id)

            customer_id = db_user.stripe_customer_id

            if plan.startswith("team"):
                require_feature(HAS_TEAM, "Team collaboration")
                team_name = checkout_map.get("team_name", f"{db_user.name}'s Team")
                team = models.Team(
                    name=team_name,
                    total_documents_limit=100 * len(team_emails),
                    daily_feedback_limit=50 * len(team_emails),
                )
                session.add(team)
                session.commit()

                for email in team_emails:
                    existing_user = session.query(models.User).filter(models.User.username == email).first()
                    # NOTE(review): invitations are only recorded for emails without an
                    # account, yet handle_successful_checkout has a branch for invitees
                    # that already have one; confirm existing users should be skipped.
                    if not existing_user:
                        invitation = models.TeamInvitation(
                            email=email,
                            inviter_id=db_user.user_id,
                            team_id=team.team_id,
                            status="pending",
                            datetime_created=datetime.now(timezone.utc),
                        )
                        session.add(invitation)
                session.commit()

            log_event("stripe_checkout_session_start", user_id=db_user.user_id)

            try:
                session_info = billing.create_checkout_session(
                    customer_id=customer_id,
                    price_id=price_id,
                    qty=quantity,
                    success_url=settings.stripe_success_url,
                    cancel_url=settings.stripe_cancel_url,
                    metadata={"plan": plan, "user_id": db_user.user_id},
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc
            except Exception as exc:
                raise HTTPException(status_code=500, detail=str(exc)) from exc

            # The provider may return an object exposing ``id``/``url`` attributes or
            # a mapping with those keys; accept either shape.
            session_id = getattr(session_info, "id", None)
            session_url = getattr(session_info, "url", None)
            if session_id is None and hasattr(session_info, "get"):
                session_id = session_info.get("id")
            if session_url is None and hasattr(session_info, "get"):
                session_url = session_info.get("url")
            if not session_id or not session_url:
                raise HTTPException(status_code=500, detail="Billing provider did not return a checkout URL.")

            if hasattr(db_user, "checkout_session"):
                db_user.checkout_session = session_id
            if hasattr(db_user, "plan"):
                db_user.plan = plan
            session.commit()
            log_event("stripe_checkout_session_created", user_id=db_user.user_id, session_id=session_id)
            return {"url": session_url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
2409
|
+
|
|
2410
|
+
|
|
2411
|
+
@router.get("/redirect-to-checkout")
def redirect_to_checkout(
    user_id: str,
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return the checkout URL for a user's stored checkout session.

    The stored ``checkout_session`` id is resolved to a URL by the billing
    provider and then cleared, so each stored session is handed out once.

    NOTE(review): ``user_id`` comes from the query string and there is no
    authentication dependency, so any caller can fetch (and clear) another
    user's checkout URL; a missing session is also reported as HTTP 500.
    Confirm whether ``current_user`` should be required and 404 returned.
    """
    require_cloud("Billing")
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        pending_session_id = user.checkout_session if user else None
        if not pending_session_id:
            raise HTTPException(status_code=500, detail=f"No checkout session found for user {user_id}")

        try:
            checkout_url = billing.get_checkout_url(pending_session_id)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc

        # One-shot: forget the session once its URL has been returned.
        user.checkout_session = None
        session.commit()

        return {"url": checkout_url}
|
|
2436
|
+
|
|
2437
|
+
|
|
2438
|
+
|
|
2439
|
+
def track_payment_method(fingerprint: str):
    """
    Enforce the per-payment-method cap on referral credits.

    Counts users whose ``payment_fingerprint`` equals ``fingerprint`` and who
    currently hold referral credits (``referral_credits > 0``). When that count
    has reached 3, raises ``HTTPException(400)``; otherwise returns ``None``.
    """
    max_credits_per_payment = 3
    with compair.Session() as session:
        users_with_credits = (
            session.query(models.User)
            .filter(
                models.User.payment_fingerprint == fingerprint,
                models.User.referral_credits > 0,
            )
            .count()
        )

    if users_with_credits < max_credits_per_payment:
        return
    raise HTTPException(
        status_code=400,
        detail="Referral credit limit reached for this payment method."
    )
|
|
2455
|
+
|
|
2456
|
+
def handle_successful_checkout(session):
    """
    Handle a completed Stripe Checkout session.

    Args:
        session: The checkout-session object from the webhook payload. Keys
            read: ``customer`` (falling back to ``customer_details.email`` when
            None), ``payment_method`` and ``payment_status``.

    If a payment method is present it is stored on the matching user as
    ``payment_fingerprint`` and checked with ``track_payment_method``. Once the
    session is ``"paid"`` and the user is on a team plan and owns a team, the
    user's ``"pending"`` team invitations are processed. Invitees with an
    account get a team-join email and the invitation is marked ``"sent"``;
    others get a Compair referral invite.
    """
    print('made it to checkout event')
    print(session)
    # Keep the Stripe payload under its own name: the DB sessions below must not
    # shadow it, or the later ``payment_status`` lookup would index a DB session.
    checkout = session
    customer_id = checkout["customer"]
    if customer_id is None:
        # NOTE(review): this falls back to an email address, but the lookups below
        # compare it against ``stripe_customer_id``, so they cannot match by email.
        customer_id = checkout["customer_details"]["email"]

    payment_method = checkout.get("payment_method")  # Might not exist if using trial
    print(customer_id)
    print(payment_method)
    if payment_method:
        with compair.Session() as db_session:
            user = db_session.query(models.User).filter(
                models.User.stripe_customer_id == customer_id
            ).first()
            print(user)
            if user:
                user.payment_fingerprint = payment_method
                track_payment_method(user.payment_fingerprint)
                db_session.commit()

    payment_status = checkout["payment_status"]
    if payment_status != "paid":
        return

    with compair.Session() as db_session:
        user = db_session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        if not user:
            return

        # ``plan`` may be unset; treat that as "not a team plan".
        if not (getattr(user, "plan", None) or "").startswith("team"):
            return

        team = db_session.query(models.Team).filter(
            models.Team.ownership.has(owner_id=user.user_id)
        ).first()
        if not team:
            return

        # Send emails for this purchaser's pending team invitations
        invitations = db_session.query(models.TeamInvitation).filter(
            models.TeamInvitation.inviter_id == user.user_id,
            models.TeamInvitation.status == "pending",
        ).all()
        for invitation in invitations:
            existing_user = db_session.query(models.User).filter(
                models.User.username == invitation.email
            ).first()

            if existing_user:
                # Send team invitation email to existing user
                emailer.connect()
                emailer.send(
                    subject="You're Invited to Join a Team on Compair",
                    sender=EMAIL_USER,
                    receivers=[invitation.email],
                    html=f"""
                    <p>{user.name} has invited you to join their team on Compair!</p>
                    <p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>
                    """
                )
                invitation.status = "sent"
            else:
                # Send invitation to join Compair. Pass group_id explicitly: called
                # directly (not via FastAPI), its default is the truthy Form(None) object.
                send_invite(
                    invitee_emails=invitation.email,
                    group_id=None,
                    current_user=user,
                )

        # Persist the "sent" status updates made above.
        db_session.commit()
|
|
2527
|
+
|
|
2528
|
+
def handle_successful_payment_intent(intent):
    """
    Record the payment method from a succeeded PaymentIntent.

    Reads ``customer`` and ``payment_method`` from ``intent``. When a payment
    method is present and a user with that ``stripe_customer_id`` exists (and
    has a ``payment_fingerprint`` attribute), the method is stored there,
    checked with ``track_payment_method`` and committed.
    """
    # TODO: store the intent -> customer mapping when payment_method is not available.
    customer_id = intent["customer"]
    payment_method = intent.get("payment_method")
    print("Got to intent")
    print(intent)
    print(payment_method)
    if not payment_method:
        return

    with compair.Session() as db_session:
        matched_user = (
            db_session.query(models.User)
            .filter(models.User.stripe_customer_id == customer_id)
            .first()
        )
        if not matched_user or not hasattr(matched_user, "payment_fingerprint"):
            return
        matched_user.payment_fingerprint = payment_method
        track_payment_method(matched_user.payment_fingerprint)
        db_session.commit()
|
|
2547
|
+
|
|
2548
|
+
def handle_successful_invoice_payment(
    invoice,
    billing: BillingProvider,
    analytics: Analytics | None = None,
):
    """
    Activate the subscriber after a paid invoice and settle referral rewards.

    Args:
        invoice: Invoice object from the webhook payload. Keys read:
            ``customer``, ``lines.data[*].price.id`` and ``payment_method``.
        billing: Billing provider used to create and apply the referral coupon.
        analytics: Optional analytics sink; tracking failures are printed.

    Steps:
        1. Store the invoice's payment method (if any) on the matching user.
        2. Mark that user ``'active'`` and record the plan resolved from the
           first line item's price id via ``plan_to_id``.
        3. If the user was referred and the referrer has pending referral
           credits, move 10 pending credits to earned credits, email the
           referrer and apply a coupon of 15 (25 for team plans) via ``billing``.
    """
    customer_id = invoice["customer"]

    # Resolve the purchased plan from the first line item carrying a price id.
    # A line's ``price`` may be absent or None, so check it before indexing.
    price_id = None
    for line in invoice["lines"]["data"]:
        price = line.get("price")
        if price and "id" in price:
            price_id = price["id"]
            break

    plan = next((name for name, pid in plan_to_id.items() if pid == price_id), None)

    payment_method = invoice.get("payment_method")
    print("Got to payment")
    print(invoice)
    print(payment_method)
    if payment_method:
        with compair.Session() as session:
            user = session.query(models.User).filter(
                models.User.stripe_customer_id == customer_id
            ).first()
            if user:
                user.payment_fingerprint = payment_method
                track_payment_method(user.payment_fingerprint)
                session.commit()

    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        print(f'invoice user: {user}')
        if not user:
            return

        user.status = 'active'  # Mark the user as subscribed
        if hasattr(user, "last_payment_date"):
            user.last_payment_date = datetime.now(timezone.utc)
        user.status_change_date = datetime.now(timezone.utc)
        if plan and hasattr(user, "plan"):
            user.plan = plan
        print(user.status)

        # Handle referrer logic
        if hasattr(user, "referred_by") and user.referred_by:
            print(user.referred_by)
            referrer = session.query(models.User).filter(
                models.User.user_id == user.referred_by
            ).first()
            print(referrer)
            if referrer and hasattr(referrer, "pending_referral_credits") and referrer.pending_referral_credits > 0:
                # Convert pending credits to earned credits
                if hasattr(referrer, "referral_credits"):
                    referrer.referral_credits += 10
                referrer.pending_referral_credits -= 10
                # Send email notification
                emailer.connect()
                emailer.send(
                    subject="You've Earned a Free Month!",
                    sender=EMAIL_USER,
                    receivers=[referrer.username],
                    html=REFERRAL_CREDIT_TEMPLATE.replace(
                        "{{user_name}}", referrer.name
                    ).replace(
                        "{{referral_credits}}", str(referrer.referral_credits),
                    )
                )
                # Create and apply a $15 coupon ($25 if team plan). Plan names are
                # "team_monthly"/"team_annual" (see plan_to_id), so match the prefix:
                # an equality test against 'team' can never succeed.
                amount = 15
                if (getattr(referrer, "plan", None) or "").startswith("team"):
                    amount = 25
                try:
                    coupon_id = billing.create_coupon(amount)
                    print(f'Coupon ID 2:{coupon_id}')
                    if referrer.stripe_customer_id:
                        billing.apply_coupon(customer_id=referrer.stripe_customer_id, coupon_id=coupon_id)
                except NotImplementedError:
                    # Providers without coupon support: the credit conversion still stands.
                    pass

        if analytics:
            try:
                analytics.track("subscription_created", user.user_id)
            except Exception as exc:
                print(f"analytics track failed: {exc}")

        session.commit()
|
|
2641
|
+
|
|
2642
|
+
def handle_subscription_cancellation(subscription):
    """
    Suspend the user whose subscription was deleted.

    Finds the user by ``subscription["customer"]`` (their ``stripe_customer_id``);
    if found, sets ``status`` to ``'suspended'``, stamps ``status_change_date``
    with the current UTC time and commits. Unknown customers are ignored.
    """
    stripe_customer = subscription["customer"]
    print('Cancelled')
    print(subscription)
    with compair.Session() as db_session:
        subscriber = (
            db_session.query(models.User)
            .filter(models.User.stripe_customer_id == stripe_customer)
            .first()
        )
        if not subscriber:
            return
        subscriber.status = 'suspended'  # Remove access
        subscriber.status_change_date = datetime.now(timezone.utc)
        db_session.commit()
|
|
2658
|
+
|
|
2659
|
+
def handle_failed_payment(invoice):
    """
    Suspend the user whose invoice payment failed.

    Finds the user by ``invoice["customer"]`` (their ``stripe_customer_id``);
    if found, sets ``status`` to ``'suspended'`` until payment is made, stamps
    ``status_change_date`` with the current UTC time and commits.
    """
    stripe_customer = invoice["customer"]
    print('Failure')
    print(invoice)
    with compair.Session() as db_session:
        payer = (
            db_session.query(models.User)
            .filter(models.User.stripe_customer_id == stripe_customer)
            .first()
        )
        if not payer:
            return
        payer.status = 'suspended'  # Suspend access until payment is made
        payer.status_change_date = datetime.now(timezone.utc)
        db_session.commit()
|
|
2675
|
+
|
|
2676
|
+
@router.post("/webhook")
async def stripe_webhook(
    request: Request,
    billing: BillingProvider = Depends(get_billing),
    analytics: Analytics = Depends(get_analytics),
    settings: Settings = Depends(get_settings_dependency),
):
    """Verify a billing webhook and dispatch it to the matching handler.

    The raw body and the ``stripe-signature`` header go to
    ``billing.construct_event``. The four supported event types are routed to
    their ``handle_*`` function with ``event["data"]["object"]``; other types
    are only logged. Any exception (bad signature or handler failure) becomes
    HTTP 400 with the error text.

    NOTE(review): ``checkout.session.completed`` is not routed here, so
    ``handle_successful_checkout`` is never invoked from this endpoint.
    """
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    raw_body = await request.body()
    signature = request.headers.get("stripe-signature")
    # Event type -> handler taking the event's data object.
    dispatch = {
        "payment_intent.succeeded": handle_successful_payment_intent,
        "invoice.payment_succeeded": lambda invoice: handle_successful_invoice_payment(invoice, billing, analytics),
        "customer.subscription.deleted": handle_subscription_cancellation,
        "invoice.payment_failed": handle_failed_payment,
    }
    try:
        event = billing.construct_event(raw_body, signature)
        event_type = event["type"]
        print(f'webhook event: {event_type}')

        log_event("stripe_webhook_start", event_type=event['type'], data=event['data']['object'])

        handler = dispatch.get(event_type)
        if handler is not None:
            handler(event["data"]["object"])

        log_event("stripe_webhook_complete", event_type=event['type'], data=event['data']['object'])

        return {"status": "success"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
2716
|
+
|
|
2717
|
+
|
|
2718
|
+
|
|
2719
|
+
def generate_referral_link(referral_code: str) -> str:
    """
    Build the login link that carries ``referral_code`` as the ``ref`` parameter.

    Calls ``require_cloud("Referral program")`` first, so it fails the same way
    that guard does when the referral program is unavailable.
    """
    require_cloud("Referral program")
    # NOTE(review): plain http here, while the team-invite email uses https://{WEB_URL}.
    return f"http://{WEB_URL}/login?ref={referral_code}"
|
|
2726
|
+
|
|
2727
|
+
@router.post("/send-invite")
def send_invite(
    invitee_emails: str = Form(...),
    group_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
):
    """
    Send an email invitation with the user's referral code.
    If group_id is provided, track the invitation for group auto-invite on sign-up.

    Args:
        invitee_emails: Comma-separated addresses; surrounding whitespace is
            trimmed and empty entries are skipped.
        group_id: Optional group id. For each invitee a ``"pending"``
            ``GroupInvitation`` with a random URL-safe token and a 7-day expiry
            (timezone-aware UTC) is stored.
        current_user: The inviting user; their ``referral_code`` builds the link.

    Returns:
        ``{"message": "Invitation sent successfully"}``.
    """
    require_cloud("Referral program")
    # Called directly (not through FastAPI), an omitted group_id is the Form(None)
    # default object, which is truthy; only accept real strings as group ids.
    if not isinstance(group_id, str):
        group_id = None
    # "a@x.com, b@x.com" must not yield " b@x.com"; trim and drop blank entries.
    invitee_emails = [email.strip() for email in invitee_emails.split(',') if email.strip()]
    with compair.Session() as session:

        referral_link = generate_referral_link(current_user.referral_code)

        # Send email notification
        emailer.connect()
        for invitee_email in invitee_emails:
            # Track pending group invitation for this email if group_id is provided
            if group_id:
                # NOTE(review): nothing checks that current_user belongs to group_id.
                token = secrets.token_urlsafe(32).lower()
                invitation = models.GroupInvitation(
                    group_id=group_id,
                    inviter_id=current_user.user_id,
                    token=token,
                    email=invitee_email,
                    # Aware UTC, matching the datetime.now(timezone.utc) expiry check
                    # (comparing naive and aware datetimes raises TypeError).
                    datetime_expiration=datetime.now(timezone.utc) + timedelta(days=7),
                    status="pending"
                )
                session.add(invitation)
                session.commit()
            # Send email notification
            emailer.send(
                subject="You're Invited to Compair!",
                sender=EMAIL_USER,
                receivers=[invitee_email],
                html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                    "{{inviter_name}}", current_user.name
                ).replace(
                    "{{referral_link}}", referral_link,
                )
            )

    return {"message": "Invitation sent successfully"}
|
|
2773
|
+
|
|
2774
|
+
@router.post("/get-customer-portal")
def get_customer_portal(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return a billing-portal URL for the current user.

    The portal sends the user back to ``http://{WEB_URL}/home``, or to
    ``https://compair.sh/home`` when ``WEB_URL`` is empty.

    Raises:
        HTTPException(501): billing disabled, or portals not implemented by the provider.
        HTTPException(400): the user has no ``stripe_customer_id``.
        HTTPException(500): any other error.
    """
    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    if WEB_URL:
        return_url = f"http://{WEB_URL}/home"
    else:
        return_url = "https://compair.sh/home"

    try:
        with compair.Session() as session:
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            has_billing_profile = bool(db_user) and bool(db_user.stripe_customer_id)
            if not has_billing_profile:
                raise HTTPException(status_code=400, detail="No billing profile found for this user.")

            try:
                portal_url = billing.create_customer_portal(
                    customer_id=db_user.stripe_customer_id,
                    return_url=return_url,
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc

            return {"url": portal_url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
2806
|
+
|
|
2807
|
+
|
|
2808
|
+
|
|
2809
|
+
|
|
2810
|
+
@router.post("/upload/profile")
async def upload_profile_image(
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """Upload a profile image to Cloudflare Images and store its id on the user.

    Args:
        upload_type: Must be ``"profile"``.
        file: The image; its declared content type must be JPEG, PNG or WebP.
        current_user: The user whose ``profile_image`` is updated.

    Returns:
        ``{"image_id": <Cloudflare image id>}``.

    Raises:
        HTTPException(400): wrong ``upload_type`` or unsupported content type.
        HTTPException(500): Cloudflare rejected the upload.
        HTTPException(404): the current user's record no longer exists.
    """
    if upload_type != 'profile':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # Read the upload once and reuse the bytes for both the upload and the logged
    # size; a second read() would start at EOF and always report 0 bytes.
    file_bytes = await file.read()
    file_size = len(file_bytes)

    # Upload to Cloudflare Images
    headers = {
        "Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"
    }
    data = {
        "requireSignedURLs": "false"
    }
    files = {
        "file": (file.filename, file_bytes, file.content_type)
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(
            CLOUDFLARE_IMAGES_UPLOAD_URL,
            headers=headers,
            data=data,
            files=files,
            timeout=30
        )
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    result = response.json()
    if not result.get("success"):
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    image_id = result["result"]["id"]

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")
        user.profile_image = image_id
        session.commit()

    # Log the upload event
    log_event(
        "cloudflare_upload",
        upload_type="profile_image",
        user_id=current_user.user_id,
        file_size=file_size,
        file_key=image_id,
        content_type=file.content_type
    )

    return {"image_id": image_id}
|
|
2871
|
+
|
|
2872
|
+
|
|
2873
|
+
@router.get("/get_profile_image")
def get_profile_image(
    variant: str = Query("public", enum=["public", "avatar", "preview"]),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """Return the image URL for the current user's profile image.

    The URL is ``{CLOUDFLARE_IMAGES_BASE_URL}/{image_id}/{variant}``.

    Args:
        variant: Image variant; one of ``public``, ``avatar`` or ``preview``.
        current_user: The authenticated user.

    Raises:
        HTTPException(400): the user has no profile image.
    """
    # Only an attribute of current_user is read and no query is issued, so no
    # DB session is opened here.
    image_id = current_user.profile_image
    if not image_id:
        raise HTTPException(status_code=400, detail="No profile image found")
    image_url = f"{CLOUDFLARE_IMAGES_BASE_URL}/{image_id}/{variant}"
    return {"url": image_url}
|
|
2884
|
+
|
|
2885
|
+
|
|
2886
|
+
@router.post("/upload/file")
async def upload_file(
    document_id: str = Form(...),
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
) -> Mapping[str, str]:
    """Store a PDF for a document through the storage provider and record its key.

    The object key is ``file/<user_id>/<sha256(filename)>/<document_id>`` and is
    saved as the document's ``file_key``.

    Returns:
        ``{"url": <value returned by storage.put_file>}``.

    Raises:
        HTTPException(400): wrong ``upload_type`` or non-PDF content type.
        HTTPException(404): ``document_id`` does not exist (checked before uploading).
        HTTPException(501): the storage provider does not implement uploads.
        HTTPException(500): any other storage failure.
    """
    if upload_type != 'file':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["application/pdf"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    file_name = file.filename
    file_name_hash = hashlib.sha256(f"{file_name}".encode()).hexdigest()
    file_key = f"{upload_type}/{current_user.user_id}/{file_name_hash}/{document_id}"

    # Read file bytes to get size and reset file pointer before upload
    file_bytes = await file.read()
    file_size = len(file_bytes)
    await file.seek(0)

    content_type = file.content_type or "application/octet-stream"

    with compair.Session() as session:
        # Look the document up before uploading, so an unknown document_id fails
        # fast instead of leaving an orphaned object in storage.
        # NOTE(review): the document's owner is not compared with current_user.
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")

        try:
            upload_url = storage.put_file(file_key, file.file, content_type)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc
        except Exception as exc:  # pragma: no cover - surfaced to client
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        document.file_key = file_key
        session.commit()

    # Log the upload event
    log_event(
        "cloudflare_upload",
        upload_type="document_file",
        user_id=current_user.user_id,
        document_id=document_id,
        file_size=file_size,
        file_key=file_key,
        content_type=content_type
    )

    return {"url": upload_url}
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
@router.post("/upload/group")
|
|
2943
|
+
async def upload_group_image(
|
|
2944
|
+
group_id: str,
|
|
2945
|
+
upload_type: str,
|
|
2946
|
+
file: UploadFile = File(...)
|
|
2947
|
+
) -> Mapping[str, str]:
|
|
2948
|
+
if upload_type!='group':
|
|
2949
|
+
raise HTTPException(status_code=400, detail="Invalid upload type")
|
|
2950
|
+
|
|
2951
|
+
allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
|
|
2952
|
+
if file.content_type not in allowed_types:
|
|
2953
|
+
raise HTTPException(status_code=400, detail="Invalid file type")
|
|
2954
|
+
|
|
2955
|
+
# Upload to Cloudflare Images
|
|
2956
|
+
headers = {
|
|
2957
|
+
"Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"
|
|
2958
|
+
}
|
|
2959
|
+
data = {
|
|
2960
|
+
"requireSignedURLs": "false"
|
|
2961
|
+
}
|
|
2962
|
+
files = {
|
|
2963
|
+
"file": (file.filename, await file.read(), file.content_type)
|
|
2964
|
+
}
|
|
2965
|
+
|
|
2966
|
+
async with httpx.AsyncClient() as client:
|
|
2967
|
+
response = await client.post(
|
|
2968
|
+
CLOUDFLARE_IMAGES_UPLOAD_URL,
|
|
2969
|
+
headers=headers,
|
|
2970
|
+
data=data,
|
|
2971
|
+
files=files,
|
|
2972
|
+
timeout=30
|
|
2973
|
+
)
|
|
2974
|
+
if response.status_code != 200:
|
|
2975
|
+
raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")
|
|
2976
|
+
|
|
2977
|
+
result = response.json()
|
|
2978
|
+
if not result.get("success"):
|
|
2979
|
+
raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")
|
|
2980
|
+
|
|
2981
|
+
image_id = result["result"]["id"]
|
|
2982
|
+
|
|
2983
|
+
with compair.Session() as session:
|
|
2984
|
+
group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
|
|
2985
|
+
group.group_image = image_id
|
|
2986
|
+
session.commit()
|
|
2987
|
+
|
|
2988
|
+
# Read file bytes to get size
|
|
2989
|
+
file_bytes = await file.read()
|
|
2990
|
+
file_size = len(file_bytes)
|
|
2991
|
+
|
|
2992
|
+
# Log the upload event
|
|
2993
|
+
log_event(
|
|
2994
|
+
"cloudflare_upload",
|
|
2995
|
+
upload_type="group_image",
|
|
2996
|
+
group_id=group_id,
|
|
2997
|
+
file_size=file_size,
|
|
2998
|
+
file_key=image_id,
|
|
2999
|
+
content_type=file.content_type
|
|
3000
|
+
)
|
|
3001
|
+
|
|
3002
|
+
@router.post("/send-feature-announcement")
|
|
3003
|
+
def send_feature_announcement(admin_key: str = Header(None)):
|
|
3004
|
+
"""
|
|
3005
|
+
Manually trigger a feature announcement email to churned users.
|
|
3006
|
+
Requires an admin API key.
|
|
3007
|
+
"""
|
|
3008
|
+
if not IS_CLOUD or send_feature_announcement_task is None:
|
|
3009
|
+
raise HTTPException(status_code=501, detail="Feature announcements require the Compair Cloud edition.")
|
|
3010
|
+
if admin_key != ADMIN_API_KEY:
|
|
3011
|
+
raise HTTPException(status_code=403, detail="Unauthorized")
|
|
3012
|
+
|
|
3013
|
+
send_feature_announcement_task.delay() # Run email sending as an async Celery task
|
|
3014
|
+
|
|
3015
|
+
return {"message": "Feature announcement email campaign triggered successfully."}
|
|
3016
|
+
|
|
3017
|
+
|
|
3018
|
+
@router.post("/retrial/{user_id}")
|
|
3019
|
+
def grant_retrial(user_id: str):
|
|
3020
|
+
"""
|
|
3021
|
+
Reactivate a churned user's trial for 2 weeks.
|
|
3022
|
+
"""
|
|
3023
|
+
if not IS_CLOUD:
|
|
3024
|
+
raise HTTPException(status_code=501, detail="Re-trials require the Compair Cloud edition.")
|
|
3025
|
+
with Session() as session:
|
|
3026
|
+
user = session.query(models.User).filter(models.User.user_id == user_id).first()
|
|
3027
|
+
|
|
3028
|
+
if not user or user.status != "suspended":
|
|
3029
|
+
raise HTTPException(status_code=400, detail="User not eligible for a re-trial.")
|
|
3030
|
+
|
|
3031
|
+
if user.retrial_count >= 1:
|
|
3032
|
+
raise HTTPException(status_code=403, detail="Re-trial limit reached.")
|
|
3033
|
+
|
|
3034
|
+
# Reactivate trial
|
|
3035
|
+
user.status = "trial"
|
|
3036
|
+
user.status_change_date = datetime.now(timezone.utc)
|
|
3037
|
+
user.trial_expiration_date = datetime.now(timezone.utc) + timedelta(days=14)
|
|
3038
|
+
user.retrial_count += 1 # Increment re-trial count
|
|
3039
|
+
user.last_retrial_date = datetime.now(timezone.utc)
|
|
3040
|
+
session.commit()
|
|
3041
|
+
|
|
3042
|
+
return {"message": "Your re-trial has been activated!"}
|
|
3043
|
+
|
|
3044
|
+
@router.post("/documents/{document_id}/notes")
|
|
3045
|
+
def create_note(
|
|
3046
|
+
document_id: str,
|
|
3047
|
+
content: str = Form(...),
|
|
3048
|
+
group_id: str = Form(None),
|
|
3049
|
+
current_user: models.User = Depends(get_current_user)
|
|
3050
|
+
):
|
|
3051
|
+
"""Create a Note for a Document. User must be in the Document's group. Also chunk/embed Note content."""
|
|
3052
|
+
with compair.Session() as session:
|
|
3053
|
+
document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3054
|
+
if not document:
|
|
3055
|
+
raise HTTPException(status_code=404, detail="Document not found")
|
|
3056
|
+
# Check group membership
|
|
3057
|
+
doc_group_ids = [g.group_id for g in document.groups]
|
|
3058
|
+
user_group_ids = [g.group_id for g in current_user.groups]
|
|
3059
|
+
if not set(doc_group_ids) & set(user_group_ids):
|
|
3060
|
+
raise HTTPException(status_code=403, detail="User not in document's group")
|
|
3061
|
+
note = models.Note(
|
|
3062
|
+
document_id=document_id,
|
|
3063
|
+
author_id=current_user.user_id,
|
|
3064
|
+
group_id=group_id or (doc_group_ids[0] if doc_group_ids else None),
|
|
3065
|
+
content=content,
|
|
3066
|
+
datetime_created=datetime.now(timezone.utc)
|
|
3067
|
+
)
|
|
3068
|
+
session.add(note)
|
|
3069
|
+
session.commit()
|
|
3070
|
+
session.refresh(note)
|
|
3071
|
+
# Chunk and embed Note content
|
|
3072
|
+
note_text_chunks = chunk_text(content)
|
|
3073
|
+
embedder = Embedder()
|
|
3074
|
+
for text in note_text_chunks:
|
|
3075
|
+
chunk_hash = hash(text)
|
|
3076
|
+
embedding = create_embedding(embedder, text, user=current_user)
|
|
3077
|
+
note_chunk = models.Chunk(
|
|
3078
|
+
hash=str(chunk_hash),
|
|
3079
|
+
document_id=document_id,
|
|
3080
|
+
note_id=note.note_id,
|
|
3081
|
+
chunk_type="note",
|
|
3082
|
+
content=text,
|
|
3083
|
+
)
|
|
3084
|
+
note_chunk.embedding = embedding
|
|
3085
|
+
session.add(note_chunk)
|
|
3086
|
+
session.commit()
|
|
3087
|
+
|
|
3088
|
+
log_activity(
|
|
3089
|
+
session=session,
|
|
3090
|
+
user_id=current_user.user_id,
|
|
3091
|
+
group_id=note.group_id,
|
|
3092
|
+
action="create",
|
|
3093
|
+
object_id=note.note_id,
|
|
3094
|
+
object_name=f"Note on {document.title}",
|
|
3095
|
+
object_type="note"
|
|
3096
|
+
)
|
|
3097
|
+
|
|
3098
|
+
return schema.Note(
|
|
3099
|
+
note_id=note.note_id,
|
|
3100
|
+
document_id=note.document_id,
|
|
3101
|
+
author_id=note.author_id,
|
|
3102
|
+
group_id=note.group_id,
|
|
3103
|
+
content=note.content,
|
|
3104
|
+
datetime_created=note.datetime_created,
|
|
3105
|
+
author=schema.User(
|
|
3106
|
+
user_id=current_user.user_id,
|
|
3107
|
+
username=current_user.username,
|
|
3108
|
+
name=current_user.name,
|
|
3109
|
+
datetime_registered=current_user.datetime_registered,
|
|
3110
|
+
status=current_user.status,
|
|
3111
|
+
profile_image=current_user.profile_image,
|
|
3112
|
+
role=current_user.role,
|
|
3113
|
+
)
|
|
3114
|
+
)
|
|
3115
|
+
|
|
3116
|
+
@router.get("/documents/{document_id}/notes")
|
|
3117
|
+
def list_notes(
|
|
3118
|
+
document_id: str,
|
|
3119
|
+
current_user: models.User = Depends(get_current_user)
|
|
3120
|
+
) -> list[schema.Note]:
|
|
3121
|
+
"""List all Notes for a Document."""
|
|
3122
|
+
with compair.Session() as session:
|
|
3123
|
+
notes = session.query(models.Note).filter(models.Note.document_id == document_id).all()
|
|
3124
|
+
result = []
|
|
3125
|
+
for note in notes:
|
|
3126
|
+
author = session.query(models.User).filter(models.User.user_id == note.author_id).first()
|
|
3127
|
+
result.append(schema.Note(
|
|
3128
|
+
note_id=note.note_id,
|
|
3129
|
+
document_id=note.document_id,
|
|
3130
|
+
author_id=note.author_id,
|
|
3131
|
+
group_id=note.group_id,
|
|
3132
|
+
content=note.content,
|
|
3133
|
+
datetime_created=note.datetime_created,
|
|
3134
|
+
author=schema.User(
|
|
3135
|
+
user_id=author.user_id,
|
|
3136
|
+
username=author.username,
|
|
3137
|
+
name=author.name,
|
|
3138
|
+
datetime_registered=author.datetime_registered,
|
|
3139
|
+
status=author.status,
|
|
3140
|
+
profile_image=author.profile_image,
|
|
3141
|
+
role=author.role,
|
|
3142
|
+
) if author else None
|
|
3143
|
+
))
|
|
3144
|
+
return result
|
|
3145
|
+
|
|
3146
|
+
@router.get("/notes/{note_id}")
|
|
3147
|
+
def get_note(
|
|
3148
|
+
note_id: str,
|
|
3149
|
+
current_user: models.User = Depends(get_current_user)
|
|
3150
|
+
) -> schema.Note:
|
|
3151
|
+
"""Get a single Note by ID."""
|
|
3152
|
+
with compair.Session() as session:
|
|
3153
|
+
note = session.query(models.Note).filter(models.Note.note_id == note_id).first()
|
|
3154
|
+
if not note:
|
|
3155
|
+
raise HTTPException(status_code=404, detail="Note not found")
|
|
3156
|
+
author = session.query(models.User).filter(models.User.user_id == note.author_id).first()
|
|
3157
|
+
return schema.Note(
|
|
3158
|
+
note_id=note.note_id,
|
|
3159
|
+
document_id=note.document_id,
|
|
3160
|
+
author_id=note.author_id,
|
|
3161
|
+
group_id=note.group_id,
|
|
3162
|
+
content=note.content,
|
|
3163
|
+
datetime_created=note.datetime_created,
|
|
3164
|
+
author=schema.User(
|
|
3165
|
+
user_id=author.user_id,
|
|
3166
|
+
username=author.username,
|
|
3167
|
+
name=author.name,
|
|
3168
|
+
datetime_registered=author.datetime_registered,
|
|
3169
|
+
status=author.status,
|
|
3170
|
+
profile_image=author.profile_image,
|
|
3171
|
+
role=author.role,
|
|
3172
|
+
) if author else None
|
|
3173
|
+
)
|
|
3174
|
+
|
|
3175
|
+
@router.get("/documents/{document_id}/file-url")
|
|
3176
|
+
def get_document_file_url(
|
|
3177
|
+
document_id: str,
|
|
3178
|
+
current_user: models.User = Depends(get_current_user),
|
|
3179
|
+
storage: StorageProvider = Depends(get_storage),
|
|
3180
|
+
):
|
|
3181
|
+
"""Return the file URL for a Document."""
|
|
3182
|
+
with compair.Session() as session:
|
|
3183
|
+
doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3184
|
+
if not doc or not doc.file_key:
|
|
3185
|
+
raise HTTPException(status_code=404, detail="File not found for this document.")
|
|
3186
|
+
file_url = storage.build_url(doc.file_key)
|
|
3187
|
+
return {"file_url": file_url}
|
|
3188
|
+
|
|
3189
|
+
@router.get("/documents/{document_id}/image-url")
|
|
3190
|
+
def get_document_image_url(
|
|
3191
|
+
document_id: str,
|
|
3192
|
+
current_user: models.User = Depends(get_current_user),
|
|
3193
|
+
storage: StorageProvider = Depends(get_storage),
|
|
3194
|
+
):
|
|
3195
|
+
"""Return the preview image URL for a Document."""
|
|
3196
|
+
with compair.Session() as session:
|
|
3197
|
+
doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3198
|
+
if not doc or not doc.image_key:
|
|
3199
|
+
raise HTTPException(status_code=404, detail="Image not found for this document.")
|
|
3200
|
+
image_url = storage.build_url(doc.image_key)
|
|
3201
|
+
return {"image_url": image_url}
|
|
3202
|
+
|
|
3203
|
+
|
|
3204
|
+
def sanitize_filename(filename: str) -> str:
|
|
3205
|
+
# Replace non-ASCII characters with underscore
|
|
3206
|
+
return re.sub(r'[^\x00-\x7F]+', '_', filename)
|
|
3207
|
+
|
|
3208
|
+
|
|
3209
|
+
@router.post("/documents/{document_id}/generate-download-token")
|
|
3210
|
+
def generate_download_token(
|
|
3211
|
+
document_id: str,
|
|
3212
|
+
current_user: models.User = Depends(get_current_user)
|
|
3213
|
+
):
|
|
3214
|
+
# Check permissions as in your download endpoint
|
|
3215
|
+
with compair.Session() as session:
|
|
3216
|
+
doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3217
|
+
if not doc or not doc.file_key:
|
|
3218
|
+
raise HTTPException(status_code=404, detail="File not found for this document.")
|
|
3219
|
+
|
|
3220
|
+
# Authorization check (same as download endpoint)
|
|
3221
|
+
if current_user.user_id == doc.author_id:
|
|
3222
|
+
pass
|
|
3223
|
+
elif doc.is_published:
|
|
3224
|
+
doc_group_ids = {g.group_id for g in doc.groups}
|
|
3225
|
+
user_group_ids = {g.group_id for g in current_user.groups}
|
|
3226
|
+
if not doc_group_ids & user_group_ids:
|
|
3227
|
+
raise HTTPException(status_code=403, detail="Not authorized to download this file.")
|
|
3228
|
+
else:
|
|
3229
|
+
raise HTTPException(status_code=403, detail="Not authorized to download this file.")
|
|
3230
|
+
|
|
3231
|
+
token = secrets.token_urlsafe(32)
|
|
3232
|
+
key = f"download_token:{token}"
|
|
3233
|
+
redis_client.setex(key, 300, document_id)
|
|
3234
|
+
print('Setting redis kv')
|
|
3235
|
+
print(key, document_id)
|
|
3236
|
+
return {"download_url": f"/documents/download/{token}"}
|
|
3237
|
+
|
|
3238
|
+
|
|
3239
|
+
@router.get("/documents/download/{token}")
|
|
3240
|
+
def download_document_with_token(
|
|
3241
|
+
token: str,
|
|
3242
|
+
storage: StorageProvider = Depends(get_storage),
|
|
3243
|
+
):
|
|
3244
|
+
key = f"download_token:{token}"
|
|
3245
|
+
print(f'Retrieving redis kv with key {key}')
|
|
3246
|
+
document_id = redis_client.get(key).decode('utf-8') if redis_client.get(key) else None
|
|
3247
|
+
print(f'Value {document_id}')
|
|
3248
|
+
if not document_id:
|
|
3249
|
+
raise HTTPException(status_code=403, detail="Invalid or expired token")
|
|
3250
|
+
redis_client.delete(key)
|
|
3251
|
+
with compair.Session() as session:
|
|
3252
|
+
doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3253
|
+
if not doc or not doc.file_key:
|
|
3254
|
+
raise HTTPException(status_code=404, detail="File not found for this document.")
|
|
3255
|
+
|
|
3256
|
+
safe_title = sanitize_filename(doc.title or 'file')
|
|
3257
|
+
try:
|
|
3258
|
+
file_obj, content_type = storage.get_file(doc.file_key)
|
|
3259
|
+
except FileNotFoundError as exc:
|
|
3260
|
+
raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc
|
|
3261
|
+
|
|
3262
|
+
return StreamingResponse(
|
|
3263
|
+
file_obj,
|
|
3264
|
+
media_type=content_type,
|
|
3265
|
+
headers={
|
|
3266
|
+
"Content-Disposition": f"attachment; filename={safe_title or 'file'}"
|
|
3267
|
+
}
|
|
3268
|
+
)
|
|
3269
|
+
|
|
3270
|
+
|
|
3271
|
+
@router.get("/documents/{document_id}/download")
|
|
3272
|
+
def download_document_file(
|
|
3273
|
+
document_id: str,
|
|
3274
|
+
current_user: models.User = Depends(get_current_user),
|
|
3275
|
+
storage: StorageProvider = Depends(get_storage),
|
|
3276
|
+
):
|
|
3277
|
+
with compair.Session() as session:
|
|
3278
|
+
doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
|
|
3279
|
+
|
|
3280
|
+
safe_title = sanitize_filename(doc.title or 'file')
|
|
3281
|
+
|
|
3282
|
+
if not doc or not doc.file_key:
|
|
3283
|
+
raise HTTPException(status_code=404, detail="File not found for this document.")
|
|
3284
|
+
|
|
3285
|
+
# Authorization check
|
|
3286
|
+
if current_user.user_id == doc.author_id:
|
|
3287
|
+
pass
|
|
3288
|
+
elif doc.is_published:
|
|
3289
|
+
# Check group membership
|
|
3290
|
+
doc_group_ids = {g.group_id for g in doc.groups}
|
|
3291
|
+
user_group_ids = {g.group_id for g in current_user.groups}
|
|
3292
|
+
if not doc_group_ids & user_group_ids:
|
|
3293
|
+
raise HTTPException(status_code=403, detail="Not authorized to download this file.")
|
|
3294
|
+
|
|
3295
|
+
# Fetch file from R2 and stream to user
|
|
3296
|
+
try:
|
|
3297
|
+
file_obj, content_type = storage.get_file(doc.file_key)
|
|
3298
|
+
except FileNotFoundError as exc:
|
|
3299
|
+
raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc
|
|
3300
|
+
|
|
3301
|
+
return StreamingResponse(
|
|
3302
|
+
file_obj,
|
|
3303
|
+
media_type=content_type,
|
|
3304
|
+
headers={
|
|
3305
|
+
"Content-Disposition": f"attachment; filename={safe_title or 'file'}"
|
|
3306
|
+
}
|
|
3307
|
+
)
|
|
3308
|
+
|
|
3309
|
+
|
|
3310
|
+
@router.post("/help-request")
|
|
3311
|
+
def submit_help_request(
|
|
3312
|
+
content: str = Form(...),
|
|
3313
|
+
current_user: models.User = Depends(get_current_user)
|
|
3314
|
+
):
|
|
3315
|
+
with compair.Session() as session:
|
|
3316
|
+
help_request = models.HelpRequest(
|
|
3317
|
+
user_id=current_user.user_id,
|
|
3318
|
+
content=content,
|
|
3319
|
+
datetime_created=datetime.now(timezone.utc)
|
|
3320
|
+
)
|
|
3321
|
+
session.add(help_request)
|
|
3322
|
+
session.commit()
|
|
3323
|
+
if not IS_CLOUD:
|
|
3324
|
+
raise HTTPException(status_code=501, detail="Help request emails require the Compair Cloud edition.")
|
|
3325
|
+
send_help_request_email.delay(help_request.request_id)
|
|
3326
|
+
return {"message": "Your request has been submitted. Our team will get back to you soon."}
|
|
3327
|
+
|
|
3328
|
+
|
|
3329
|
+
@router.post("/deactivate-account")
|
|
3330
|
+
def submit_deactivate_request(
|
|
3331
|
+
notice: str = Form(...),
|
|
3332
|
+
current_user: models.User = Depends(get_current_user)
|
|
3333
|
+
):
|
|
3334
|
+
with compair.Session() as session:
|
|
3335
|
+
deactivate_request = models.DeactivateRequest(
|
|
3336
|
+
user_id=current_user.user_id,
|
|
3337
|
+
notice=notice,
|
|
3338
|
+
datetime_created=datetime.now(timezone.utc)
|
|
3339
|
+
)
|
|
3340
|
+
session.add(deactivate_request)
|
|
3341
|
+
session.commit()
|
|
3342
|
+
if not IS_CLOUD:
|
|
3343
|
+
raise HTTPException(status_code=501, detail="Deactivate request emails require the Compair Cloud edition.")
|
|
3344
|
+
send_deactivate_request_email.delay(deactivate_request.request_id)
|
|
3345
|
+
return {"message": f"We’ve received your request and will delete your account and data shortly. If you change your mind, reach out within 24 hours at {EMAIL_USER}."}
|
|
3346
|
+
|
|
3347
|
+
|
|
3348
|
+
def create_fastapi_app():
|
|
3349
|
+
"""Backwards-compatible app factory for running this module directly."""
|
|
3350
|
+
from fastapi import FastAPI
|
|
3351
|
+
|
|
3352
|
+
fastapi_app = FastAPI()
|
|
3353
|
+
fastapi_app.include_router(router)
|
|
3354
|
+
return fastapi_app
|
|
3355
|
+
|
|
3356
|
+
|
|
3357
|
+
app = create_fastapi_app()
|