compair-core 0.4.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. compair_core/__init__.py +8 -0
  2. compair_core/api.py +3598 -0
  3. compair_core/compair/__init__.py +57 -0
  4. compair_core/compair/celery_app.py +31 -0
  5. compair_core/compair/default_groups.py +14 -0
  6. compair_core/compair/embeddings.py +141 -0
  7. compair_core/compair/feedback.py +368 -0
  8. compair_core/compair/logger.py +29 -0
  9. compair_core/compair/main.py +276 -0
  10. compair_core/compair/models.py +453 -0
  11. compair_core/compair/schema.py +146 -0
  12. compair_core/compair/tasks.py +106 -0
  13. compair_core/compair/utils.py +42 -0
  14. compair_core/compair_email/__init__.py +0 -0
  15. compair_core/compair_email/email.py +6 -0
  16. compair_core/compair_email/email_core.py +15 -0
  17. compair_core/compair_email/templates.py +6 -0
  18. compair_core/compair_email/templates_core.py +32 -0
  19. compair_core/db.py +64 -0
  20. compair_core/server/__init__.py +0 -0
  21. compair_core/server/app.py +97 -0
  22. compair_core/server/deps.py +77 -0
  23. compair_core/server/local_model/__init__.py +1 -0
  24. compair_core/server/local_model/app.py +87 -0
  25. compair_core/server/local_model/ocr.py +107 -0
  26. compair_core/server/providers/__init__.py +0 -0
  27. compair_core/server/providers/console_mailer.py +9 -0
  28. compair_core/server/providers/contracts.py +66 -0
  29. compair_core/server/providers/http_ocr.py +60 -0
  30. compair_core/server/providers/local_storage.py +28 -0
  31. compair_core/server/providers/noop_analytics.py +7 -0
  32. compair_core/server/providers/noop_billing.py +30 -0
  33. compair_core/server/providers/noop_ocr.py +10 -0
  34. compair_core/server/routers/__init__.py +0 -0
  35. compair_core/server/routers/capabilities.py +46 -0
  36. compair_core/server/settings.py +66 -0
  37. compair_core-0.4.12.dist-info/METADATA +136 -0
  38. compair_core-0.4.12.dist-info/RECORD +41 -0
  39. compair_core-0.4.12.dist-info/WHEEL +5 -0
  40. compair_core-0.4.12.dist-info/licenses/LICENSE +674 -0
  41. compair_core-0.4.12.dist-info/top_level.txt +1 -0
compair_core/api.py ADDED
@@ -0,0 +1,3598 @@
1
+ import hashlib
2
+ import os
3
+ import re
4
+ import requests
5
+ import secrets
6
+ import threading
7
+ import time
8
+ from datetime import datetime, timedelta, timezone
9
+ from typing import Any, Mapping, Optional, Tuple
10
+
11
+ import httpx
12
+ import psutil
13
+ from celery.result import AsyncResult
14
+ from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile
15
+ from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
16
+ from fastapi.routing import APIRoute
17
+ from sqlalchemy import distinct, func, select, or_
18
+ from sqlalchemy.orm import joinedload, Session
19
+
20
+ from .server.deps import get_analytics, get_billing, get_ocr, get_settings_dependency, get_storage
21
+ from .server.providers.contracts import Analytics, BillingProvider, OCRProvider, StorageProvider
22
+ from .server.settings import Settings
23
+
24
+ from . import compair
25
+ from .compair import models, schema
26
+ from .compair.embeddings import create_embedding, Embedder
27
+ from .compair.logger import log_event
28
+ from .compair.utils import chunk_text, generate_verification_token, log_activity
29
+ from .compair_email.email import emailer, EMAIL_USER
30
+ from .compair_email.templates import (
31
+ ACCOUNT_VERIFY_TEMPLATE,
32
+ GROUP_INVITATION_TEMPLATE,
33
+ GROUP_JOIN_TEMPLATE,
34
+ INDIVIDUAL_INVITATION_TEMPLATE,
35
+ PASSWORD_RESET_TEMPLATE,
36
+ REFERRAL_CREDIT_TEMPLATE
37
+ )
38
+ from .compair.tasks import process_document_task as process_document_celery, send_feature_announcement_task, send_deactivate_request_email, send_help_request_email
39
+
40
+ try:
41
+ import redis # type: ignore
42
+ except ImportError: # pragma: no cover - optional dependency
43
+ redis = None
44
+
45
+
46
def _getenv(*names: str, default: Optional[str] = None) -> Optional[str]:
    """Look up several environment variables in priority order.

    Empty variable names are skipped, and so are variables that are unset or
    set to the empty string. The first non-empty value wins; when none is
    found, ``default`` is returned.
    """
    candidates = (os.getenv(name) for name in names if name)
    return next((value for value in candidates if value), default)
55
+
56
# Optional Redis connection: created only when the ``redis`` package imported
# successfully and COMPAIR_REDIS_URL (preferred) or REDIS_URL is set.
redis_url = _getenv("COMPAIR_REDIS_URL", "REDIS_URL")
redis_client = redis.Redis.from_url(redis_url) if (redis and redis_url) else None

# API routers; the endpoints in this module register on ``router``.
router = APIRouter()
core_router = APIRouter()
# Host name (no scheme) used to build links placed in outgoing emails.
WEB_URL = os.environ.get("WEB_URL")
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY")

# Cloudflare Images settings.
# NOTE(review): when the account variables are unset, the derived URLs below
# contain the literal text "None"; confirm callers check the raw settings first.
CLOUDFLARE_IMAGES_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_ACCOUNT")
CLOUDFLARE_IMAGES_URL_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_URL_ACCOUNT")
CLOUDFLARE_IMAGES_BASE_URL = f"https://imagedelivery.net/{CLOUDFLARE_IMAGES_URL_ACCOUNT}"
CLOUDFLARE_IMAGES_TOKEN = os.environ.get("CLOUDFLARE_IMAGES_TOKEN")
CLOUDFLARE_IMAGES_UPLOAD_URL = f"https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_IMAGES_ACCOUNT}/images/v1"

# Google Analytics 4 credentials; the COMPAIR_-prefixed variables take precedence.
GA4_MEASUREMENT_ID = _getenv("COMPAIR_GA4_MEASUREMENT_ID", "GA4_MEASUREMENT_ID")
GA4_API_SECRET = _getenv("COMPAIR_GA4_API_SECRET", "GA4_API_SECRET")

# COMPAIR_EDITION=cloud (case-insensitive) selects the Cloud edition; default is "core".
IS_CLOUD = os.getenv("COMPAIR_EDITION", "core").lower() == "cloud"
# Lifetime of the session issued to the singleton user when authentication is disabled.
SINGLE_USER_SESSION_TTL = timedelta(days=365)
76
+
77
+
78
def _render_email(template: str, **context: Optional[str]) -> str:
    """Fill ``{{key}}`` placeholders in an email template.

    Every ``{{key}}`` whose ``key`` appears in ``context`` is replaced by the
    corresponding value. ``None`` renders as an empty string, and non-string
    values are converted with ``str``.

    Substitution is a single pass over the template. Placeholder-like text
    inside a value (for example a user name such as ``"{{invitation_link}}"``)
    is therefore inserted literally instead of being expanded by another key.
    Placeholders with no matching key are left untouched.

    Values are NOT HTML-escaped; callers must escape untrusted text.

    Args:
        template: Template text containing ``{{placeholder}}`` markers.
        **context: Placeholder names mapped to replacement values.

    Returns:
        The rendered text.
    """
    if not context:
        return template
    # The closing ``}}`` anchors each alternative, so key names that share a
    # prefix (e.g. ``name`` / ``name_full``) cannot mis-match.
    pattern = re.compile(
        r"\{\{(" + "|".join(re.escape(key) for key in context) + r")\}\}"
    )

    def _replacement(match):
        value = context[match.group(1)]
        return "" if value is None else str(value)

    return pattern.sub(_replacement, template)
85
+
86
+
87
def _dispatch_process_document_task(
    user_id: str,
    doc_id: str,
    doc_text: str,
    generate_feedback: bool,
):
    """Run document processing, queueing it when a ``delay`` method exists.

    If ``process_document_celery`` exposes a callable ``delay`` attribute, the
    job is submitted through it and whatever ``delay`` returns is passed back.
    Otherwise the function is called inline and its result is returned.
    """
    args = (user_id, doc_id, doc_text, generate_feedback)
    enqueue = getattr(process_document_celery, "delay", None)
    if not callable(enqueue):
        return process_document_celery(*args)
    return enqueue(*args)
97
+
98
+
99
def _ensure_single_user(session: Session, settings: Settings) -> models.User:
    """Create or fetch the singleton user used when authentication is disabled.

    The user is looked up by ``settings.single_user_username``.

    If it does not exist, this creates:
      * the user, active, with a random password;
      * an ``Administrator`` row;
      * a private group named after the username, administered by that row.

    If it exists, any missing piece is repaired:
      * the status is reset to ``"active"``;
      * the private group is found or created and attached to the user;
      * the ``Administrator`` row is found or created and made an admin of
        that group.

    The session is committed only when something changed, after which the
    user is re-queried with ``groups`` eagerly loaded.

    Args:
        session: Open database session (may be committed).
        settings: Supplies ``single_user_username`` and ``single_user_name``.

    Returns:
        The singleton ``models.User`` with its ``groups`` relationship loaded.

    Raises:
        RuntimeError: if the user cannot be re-loaded after committing.
    """
    # Set whenever a row is created or modified, so we only commit when needed.
    changed = False
    user = (
        session.query(models.User)
        .options(joinedload(models.User.groups))
        .filter(models.User.username == settings.single_user_username)
        .first()
    )
    if user is None:
        now = datetime.now(timezone.utc)
        user = models.User(
            username=settings.single_user_username,
            name=settings.single_user_name,
            datetime_registered=now,
            verification_token=None,
            token_expiration=None,
        )
        # Random password: in this mode /login returns the local session without checking one.
        user.set_password(secrets.token_urlsafe(16))
        user.status = "active"
        user.status_change_date = now
        session.add(user)
        # Flush so user.user_id is populated before the Administrator row references it.
        session.flush()
        admin = models.Administrator(user_id=user.user_id)
        group = models.Group(
            name=user.username,
            datetime_created=now,
            group_image=None,
            category="Private",
            description=f"Private workspace for {settings.single_user_name}",
            visibility="private",
        )
        group.admins.append(admin)
        user.groups = [group]
        session.add_all([group, admin])
        changed = True
    else:
        now = datetime.now(timezone.utc)
        # Reactivate the account if its status drifted.
        if user.status != "active":
            user.status = "active"
            user.status_change_date = now
            changed = True
        # The private workspace group shares the user's username as its name.
        group = next((g for g in user.groups if g.name == user.username), None)
        if group is None:
            group = session.query(models.Group).filter(models.Group.name == user.username).first()
        if group is None:
            group = models.Group(
                name=user.username,
                datetime_created=now,
                group_image=None,
                category="Private",
                description=f"Private workspace for {user.name}",
                visibility="private",
            )
            session.add(group)
            changed = True
        if group not in user.groups:
            user.groups.append(group)
            changed = True
        # Reuse the user's Administrator row if it exists, then make sure it administers the group.
        admin = session.query(models.Administrator).filter(models.Administrator.user_id == user.user_id).first()
        if admin is None:
            admin = models.Administrator(user_id=user.user_id)
            session.add(admin)
            changed = True
        if admin not in group.admins:
            group.admins.append(admin)
            changed = True

    if changed:
        session.commit()
        # Re-query so the returned instance and its groups reflect the committed state.
        user = (
            session.query(models.User)
            .options(joinedload(models.User.groups))
            .filter(models.User.username == settings.single_user_username)
            .first()
        )
    if user is None:
        raise RuntimeError("Failed to initialize the local Compair user.")
    user.groups # ensure relationship is loaded before detaching
    return user
179
+
180
+
181
def _ensure_single_user_session(session: Session, user: models.User) -> models.Session:
    """Return a long-lived session token for the singleton user.

    If the user has unexpired sessions, the one that expires last is reused.
    Otherwise a new session valid for ``SINGLE_USER_SESSION_TTL`` is created,
    committed and returned.
    """
    now = datetime.now(timezone.utc)
    still_valid = (
        models.Session.user_id == user.user_id,
        models.Session.datetime_valid_until >= now,
    )
    latest = (
        session.query(models.Session)
        .filter(*still_valid)
        .order_by(models.Session.datetime_valid_until.desc())
        .first()
    )
    if latest:
        return latest

    fresh = models.Session(
        id=secrets.token_urlsafe(),
        user_id=user.user_id,
        datetime_created=now,
        datetime_valid_until=now + SINGLE_USER_SESSION_TTL,
    )
    session.add(fresh)
    session.commit()
    return fresh
202
+
203
+
204
def require_cloud(feature: str) -> None:
    """Raise HTTP 501 unless this deployment is the Compair Cloud edition.

    Args:
        feature: Human-readable feature name used in the error detail.
    """
    if IS_CLOUD:
        return
    message = f"{feature} is only available in the Compair Cloud edition."
    raise HTTPException(status_code=501, detail=message)
207
+
208
+
209
def _user_plan(user: models.User) -> str:
    """Return the user's plan name, or ``"free"`` when it is missing or empty."""
    plan = getattr(user, "plan", None)
    return plan if plan else "free"
211
+
212
+
213
def _user_team(user: models.User):
    """Return the user's ``team`` attribute, or ``None`` if the model has none."""
    try:
        return user.team
    except AttributeError:
        return None
215
+
216
+
217
def _trial_expiration(user: models.User) -> datetime | None:
    """Return ``user.trial_expiration_date``, or ``None`` if the model lacks it."""
    try:
        return user.trial_expiration_date
    except AttributeError:
        return None
219
+
220
+
221
# Feature flags derived from which optional models / User columns the
# imported ``models`` module defines; endpoints pass them to require_feature()
# to gate functionality in the core edition.
HAS_TEAM = hasattr(models, "Team")
HAS_ACTIVITY = hasattr(models, "Activity")
HAS_REFERRALS = hasattr(models.User, "referral_code")
HAS_BILLING = hasattr(models.User, "stripe_customer_id")
HAS_TRIALS = hasattr(models.User, "trial_expiration_date")
# True only when a Redis client was configured at import time.
HAS_REDIS = redis_client is not None
227
+
228
+
229
def require_feature(flag: bool, feature: str) -> None:
    """Raise HTTP 501 when ``feature`` is unavailable.

    The feature counts as available when ``flag`` is true or when this is the
    Compair Cloud edition.

    Args:
        flag: Whether this build supports the feature (e.g. a ``HAS_*`` flag).
        feature: Human-readable feature name used in the error detail.
    """
    if flag or IS_CLOUD:
        return
    raise HTTPException(
        status_code=501,
        detail=f"{feature} is only available in the Compair Cloud edition.",
    )
232
+
233
def get_current_user(auth_token: str | None = Header(None)):
    """FastAPI dependency that resolves the requesting user.

    With authentication disabled, the singleton local user is returned.
    Otherwise the ``auth_token`` header must name an unexpired
    ``models.Session``, and the owning user is returned with ``groups``
    eagerly loaded.

    Raises:
        HTTPException: 401 for a missing, unknown or expired token; 404 if
            the session's user no longer exists.
    """
    settings = get_settings_dependency()
    if not settings.require_authentication:
        with compair.Session() as session:
            return _ensure_single_user(session, settings)

    if not auth_token:
        raise HTTPException(status_code=401, detail="Missing session token")

    with compair.Session() as session:
        user_session = (
            session.query(models.Session)
            .filter(models.Session.id == auth_token)
            .first()
        )
        expired = True
        if user_session:
            expires_at = user_session.datetime_valid_until
            # Naive timestamps from the database are interpreted as UTC.
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            expired = expires_at < datetime.now(timezone.utc)
        if expired:
            raise HTTPException(status_code=401, detail="Invalid or expired session token")

        user = (
            session.query(models.User)
            .options(joinedload(models.User.groups))
            .filter(models.User.user_id == user_session.user_id)
            .first()
        )
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        return user
260
+
261
def get_current_user_with_access_to_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> models.User:
    """Dependency returning ``current_user`` if they may access ``document_id``.

    Access is granted to:
      * the document's author;
      * members of any group the document belongs to;
      * anyone, when the document is published.

    Raises:
        HTTPException: 404 if the document does not exist, 403 otherwise.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        # Checks run in order and short-circuit, so group data is only loaded
        # when the user is not the author.
        allowed = (
            doc.author_id == current_user.user_id
            or not {g.group_id for g in doc.groups}.isdisjoint(
                g.group_id for g in current_user.groups
            )
            or doc.is_published
        )
        if allowed:
            return current_user
        raise HTTPException(status_code=403, detail="Not authorized to access this document")
281
+
282
def log_service_resource_metrics(service_name="backend", *, interval_seconds: float = 300):
    """Log this process's memory and CPU usage now, then every ``interval_seconds``.

    Each sample emits a ``service_resource`` event via ``log_event``. The
    event carries the resident memory in MiB and the CPU percentage measured
    over ``psutil``'s 1-second sampling window, so each sample blocks its
    thread for about a second.

    The next sample is scheduled on a *daemon* ``threading.Timer``. The
    self-rescheduling loop therefore never keeps the interpreter alive at
    shutdown; a non-daemon timer chain would block process exit indefinitely.

    Args:
        service_name: Value reported in the event's ``service`` field.
        interval_seconds: Delay between samples (default 300 s = 5 minutes).
    """
    def log():
        try:
            p = psutil.Process(os.getpid())
            mem_mb = round(p.memory_info().rss / 1024 / 1024, 2)
            cpu_percent = p.cpu_percent(interval=1)
            log_event("service_resource", service=service_name, memory_mb=mem_mb, cpu_percent=cpu_percent)
        except Exception as e:
            # Best-effort telemetry: a metrics failure must never break the service.
            print(f"[Resource Log Error] {e}")
        finally:
            # Re-schedule the next sample on a daemon thread so it does not
            # prevent interpreter shutdown.
            timer = threading.Timer(interval_seconds, log)
            timer.daemon = True
            timer.start()
    log()
295
+
296
# Import-time side effect: start periodic resource logging for this process.
# The first sample is taken synchronously here (including psutil's 1-second
# CPU sampling window); later samples run on a timer thread.
log_service_resource_metrics(service_name="backend") # or "frontend"

# Run via: fastapi dev api.py
299
+
300
+
301
@router.post("/login")
def login(request: schema.LoginRequest) -> dict:
    """Authenticate a user and return their profile plus a session token.

    When authentication is disabled, the credentials are ignored and the
    singleton local user's long-lived session is returned. Otherwise the
    username and password are verified and a new session valid for one day
    is created.

    Raises:
        HTTPException: 401 for bad credentials; 403 if the account status is
            ``"inactive"``.
    """
    settings = get_settings_dependency()

    def _profile(account, token_row):
        return {
            "user_id": account.user_id,
            "username": account.username,
            "name": account.name,
            "status": account.status,
            "role": account.role,
            "auth_token": token_row.id,  # the session id doubles as the bearer token
        }

    with compair.Session() as session:
        if not settings.require_authentication:
            local_user = _ensure_single_user(session, settings)
            return _profile(local_user, _ensure_single_user_session(session, local_user))

        account = (
            session.query(models.User)
            .filter(models.User.username == request.username)
            .first()
        )
        credentials_ok = account is not None and account.check_password(request.password)
        if not credentials_ok:
            raise HTTPException(status_code=401, detail="Invalid credentials")
        if account.status == 'inactive':
            raise HTTPException(status_code=403, detail="User account is not verified")

        issued_at = datetime.now(tz=timezone.utc)
        new_session = models.Session(
            id=secrets.token_urlsafe(),
            user_id=account.user_id,
            datetime_created=issued_at,
            datetime_valid_until=issued_at + timedelta(days=1),
        )
        session.add(new_session)
        session.commit()
        return _profile(account, new_session)
338
+
339
+
340
@router.get("/username_exists")
def username_exists(username: str) -> dict:
    """Report whether an account with exactly this username exists.

    Returns:
        ``{"exists": bool}``.
    """
    with compair.Session() as session:
        match = (
            session.query(models.User)
            .filter(models.User.username == username)
            .first()
        )
    return {"exists": match is not None}
345
+
346
+
347
# ``name="load_user"`` keeps the route name (and therefore the OpenAPI
# operation id and summary) unchanged. The Python function is renamed
# because two later definitions reuse the name ``load_user``.
@router.get("/load_user", name="load_user")
def load_user_by_username(
    username: str,
) -> schema.User | None:
    """Return the user whose username equals ``username``, or ``None``.

    The lookup uses exact equality, as ``/username_exists`` and ``/login`` do.
    ``ColumnOperators.match`` would render a backend-specific full-text search
    operator; that operator can select a different user, or fail on backends
    without full-text support.

    NOTE(review): this endpoint requires no authentication yet returns a full
    user record; confirm that exposure is intended.
    """
    with compair.Session() as session:
        user = session.execute(
            select(models.User).filter(models.User.username == username)
        ).scalars().first()
        if user is None:
            return None
        # Touch the relationship while the session is still open; presumably
        # this loads groups so the response can be serialised after close.
        if user.groups is None:
            user.groups = []
        return user
363
+
364
+
365
# The route keeps its historical name "load_user" (operation id / summary
# unchanged). The function gets its own name so it no longer shadows, and is
# no longer shadowed by, the other ``load_user`` definitions in this module.
@router.get("/load_user_plan", name="load_user")
def load_user_plan(
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return ``{"plan": <name>}`` for the current user.

    Falls back to ``"free"`` when no plan is set.
    """
    return {'plan': _user_plan(current_user)}
370
+
371
+
372
@router.get("/load_user_by_id")
def load_user(
    user_id: str,
) -> schema.User | None:
    """Return the user with primary key ``user_id``, or ``None`` if absent.

    NOTE(review): this endpoint requires no authentication yet returns a full
    user record; confirm that exposure is intended.
    """
    with compair.Session() as session:
        user = session.execute(
            select(models.User).filter(models.User.user_id == user_id)
        ).scalars().first()
        if user is None:
            return None
        # Touch the relationship while the session is still open; presumably
        # this loads groups so the response can be serialised after close.
        if user.groups is None:
            user.groups = []
        # (A debug print of the whole user object, which leaked user data
        # into stdout on every request, was removed here.)
        return user
388
+
389
+
390
@router.get("/load_user_files")
def load_user_files(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """List another user's documents that the current user may see.

    The connection must share at least one group with the current user and
    must not hide their affiliations. Visible documents are those with
    ``Document.user_id == connection_id`` that belong either to a shared
    group or to a public group the connection is a member of.

    ``filter_type`` values:
      * ``"published"`` / ``"unpublished"`` – filter on ``is_published``;
      * ``"recently_updated"`` – modified, or given a note, in the last 7 days;
      * ``"recently_compaired"`` – received feedback in the last 7 days;
      * anything else – no extra filter.

    Args:
        connection_id: ``user_id`` of the user whose documents are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Documents per page; values below 1 are treated as 1.
        filter_type: Optional filter, see above.
        current_user: The requesting user (injected).

    Returns:
        ``{"files": [schema.Document, ...], "message": str | None,
        "total_count": int}``. Every branch uses the same keys.
    """
    # Non-positive paging values would otherwise yield a negative OFFSET/LIMIT
    # and a database error.
    page = max(page, 1)
    page_size = max(page_size, 1)

    with compair.Session() as session:
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"files": [], "message": "User or connection not found.", "total_count": 0}

        # Security: the two users must share a group.
        shared_group_ids = {g.group_id for g in current_user.groups}.intersection(
            g.group_id for g in connection.groups
        )
        if not shared_group_ids:
            return {
                "files": [],
                "message": "You do not have permission to view this user's affiliations, as you do not share a group together.",
                "total_count": 0,
            }

        # Privacy: respect the connection's setting.
        if connection.hide_affiliations:
            return {
                "files": [],
                "message": f"{connection.username} has set their profile to private.",
                "total_count": 0,
            }

        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        q = select(models.Document).join(models.Document.groups).filter(
            models.Document.user_id == connection_id,
            models.Group.group_id.in_(shared_group_ids)  # shared groups
            | (
                (models.Group.visibility == "public")  # public groups ...
                & models.Group.users.any(models.User.user_id == connection_id)  # ... the connection belongs to
            ),
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups),
        )

        if filter_type == "published":
            q = q.filter(models.Document.is_published == True)
        elif filter_type == "unpublished":
            q = q.filter(models.Document.is_published == False)
        elif filter_type == "recently_updated":
            # Documents modified OR annotated with a note in the past week.
            q = q.outerjoin(models.Document.notes).filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Note.datetime_created >= week_ago,
                )
            )
        elif filter_type == "recently_compaired":
            # Documents that received feedback in the past week.
            q = q.join(models.Document.chunks).join(models.Chunk.feedbacks).filter(
                models.Feedback.timestamp >= week_ago
            )
        q = q.order_by(models.Document.datetime_modified.desc())

        # unique() collapses the duplicate rows produced by the joins above.
        total_count = len(session.execute(q).unique().fetchall())
        if total_count == 0:
            return {"files": [], "message": None, "total_count": 0}

        offset = (page - 1) * page_size
        page_rows = session.execute(
            q.order_by(models.Document.datetime_created.desc()).offset(offset).limit(page_size)
        ).unique().fetchall()
        files = [row[0] for row in page_rows]
        return {
            "files": [schema.Document.model_validate(f) for f in files],
            "message": None,
            "total_count": total_count,
        }
469
+
470
+
471
@router.get("/load_user_groups")
def load_user_groups(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """List another user's groups that the current user may see.

    The connection must share at least one group with the current user and
    must not hide their affiliations. Returned groups are the shared ones
    plus public groups the connection is a member of.

    ``filter_type`` values:
      * ``"internal"`` / ``"public"`` / ``"private"`` – restrict by visibility;
      * ``"recently_updated"`` – groups with a document created in the last
        7 days;
      * anything else – no extra filter.

    Results are always ordered by name (then id) so pages are stable.

    Args:
        connection_id: ``user_id`` of the user whose groups are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Groups per page; values below 1 are treated as 1.
        filter_type: Optional filter, see above.
        current_user: The requesting user (injected).

    Returns:
        ``{"groups": [dict, ...], "message": str | None, "total_count": int}``.
    """
    # Non-positive paging values would otherwise yield a negative OFFSET/LIMIT.
    page = max(page, 1)
    page_size = max(page_size, 1)

    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"groups": [], "message": "User or connection not found.", "total_count": 0}

        # Security: the two users must share a group.
        shared_group_ids = {g.group_id for g in current_user.groups}.intersection(
            g.group_id for g in connection.groups
        )
        if not shared_group_ids:
            return {
                "groups": [],
                "message": "You do not have permission to view this user's affiliations, as you do not share a group together.",
                "total_count": 0,
            }

        # Privacy: respect the connection's setting.
        if connection.hide_affiliations:
            return {
                "groups": [],
                "message": f"{connection.username} has set their profile to private.",
                "total_count": 0,
            }

        groups_query = session.query(models.Group).filter(
            models.Group.group_id.in_(shared_group_ids)  # shared groups
            | (
                (models.Group.visibility == "public")  # public groups ...
                & models.Group.users.any(models.User.user_id == connection_id)  # ... the connection belongs to
            )
        ).options(
            joinedload(models.Group.users),
            joinedload(models.Group.documents),
        )

        if filter_type in ("internal", "public", "private"):
            groups_query = groups_query.filter(models.Group.visibility == filter_type)
        elif filter_type == "recently_updated":
            # EXISTS instead of JOIN: a join yields one row per matching
            # document, which made count() over-count groups.
            groups_query = groups_query.filter(
                models.Group.documents.any(models.Document.datetime_created >= week_ago)
            )
        # Deterministic order for every filter so OFFSET/LIMIT pages don't overlap.
        groups_query = groups_query.order_by(models.Group.name.asc(), models.Group.group_id.asc())

        total_count = groups_query.count()
        offset = (page - 1) * page_size
        groups = groups_query.offset(offset).limit(page_size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": getattr(group, "document_count", None),
                "user_count": getattr(group, "user_count", None),
                "first_three_user_profile_images": getattr(group, "first_three_user_profile_images", None)
            }
            for group in groups
        ]
        return {"groups": result, "message": None, "total_count": total_count}
542
+
543
+
544
# The route name stays "load_user_status" (operation id / summary unchanged).
# The function is renamed because the /load_user_status_date handler
# redefines ``load_user_status`` and previously shadowed this one.
@router.get("/load_user_status", name="load_user_status")
def load_current_user_status(
    current_user: models.User = Depends(get_current_user)
) -> str:
    """Return the current user's status string (e.g. ``"active"``, ``"trial"``).

    The value is already on the injected user, so no database session is needed.
    """
    return current_user.status
552
+
553
+
554
@router.get("/load_user_status_date")
def load_user_status(
    current_user: models.User = Depends(get_current_user)
) -> datetime:
    """Return the date that matters for the current user's status.

    * ``"active"``    -> ``last_payment_date`` (requires billing support)
    * ``"trial"``     -> ``trial_expiration_date`` (requires trial support)
    * ``"suspended"`` -> ``status_change_date``

    All values are read from the injected user, so no database session is
    opened.

    Raises:
        HTTPException: 501 when the required feature is unavailable in this
            edition; 403 for any other status.
    """
    if not (HAS_TRIALS or HAS_BILLING) and not IS_CLOUD:
        raise HTTPException(status_code=501, detail="User status dates are only tracked in the Compair Cloud edition.")

    user_status = current_user.status
    if user_status == 'active':
        require_feature(HAS_BILLING, "Billing history")
        return current_user.last_payment_date
    if user_status == 'trial':
        require_feature(HAS_TRIALS, "Trial management")
        return current_user.trial_expiration_date
    if user_status == 'suspended':
        return current_user.status_change_date
    raise HTTPException(status_code=403, detail='User Inactive')
573
+
574
+
575
@router.get("/load_referral_credits")
def load_referral_credits(
    current_user: models.User = Depends(get_current_user)
) -> Tuple[int, int]:
    """Return ``(earned, pending)`` referral credits for the current user.

    Missing or ``None`` values count as 0. This can happen in the Cloud
    edition when the models do not define the referral columns; previously
    that case raised ``AttributeError``.

    Raises:
        HTTPException: 501 when referrals are unavailable in this edition.
    """
    if not HAS_REFERRALS and not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Referral credits are only available in the Compair Cloud edition.")
    referral_credits_earned = getattr(current_user, "referral_credits", 0) or 0
    referral_credits_pending = getattr(current_user, "pending_referral_credits", 0) or 0
    return (referral_credits_earned, referral_credits_pending)
585
+
586
+
587
def create_user(
    username: str,
    name: str,
    password: str,
    session: Session,
    groups: list[str] | None = None,
    referral_code: str | None = None,
    analytics_provider: Optional[Analytics] = None,
):
    """Register a new, unverified user and run the signup side effects.

    Steps, in order:
      1. Create the user with a fresh (lower-cased) verification token.
      2. When teams are supported, email every pending team invitation
         addressed to ``username`` and mark it ``"sent"``.
      3. Join the given ``groups``, or else create a private group
         administered by the user.
      4. If ``referral_code`` matches a user, record that referrer on the
         new user and add a pending $10 credit for the referrer.
      5. Report a ``user_signup`` analytics event (best effort).

    Args:
        username: Login name / email address; must not already exist.
        name: Display name.
        password: Plain-text password, hashed by ``User.set_password``.
        session: Open database session; committed several times.
        groups: Group ids to join instead of creating a private group.
        referral_code: Optional referral code of the referring user.
        analytics_provider: Provider used to report the signup. When omitted,
            a module-level ``analytics`` object is used if one exists.
            NOTE(review): no such global is defined in the visible part of
            this module, so callers should pass a provider.

    Returns:
        The committed ``models.User``.

    Raises:
        HTTPException: 400 if the username is taken; 501 if a referral code
            is supplied but referrals are unavailable in this edition.
    """
    from html import escape as html_escape

    if referral_code:
        # Check before anything is persisted: raising after the user row is
        # committed would leave an account behind while the client sees an error.
        require_feature(HAS_REFERRALS, "Referral program")

    token, expiration = generate_verification_token()
    existing_user = session.query(models.User).filter(models.User.username == username).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already in use")

    # Timezone-aware UTC, consistent with the other timestamps this module writes.
    now = datetime.now(timezone.utc)
    user = models.User(
        username=username,
        name=name,
        datetime_registered=now,
        verification_token=token.lower(),
        token_expiration=expiration,
    )
    user.set_password(password=password)
    session.add(user)
    session.commit()

    # --- Pending team invitations addressed to this email ---
    if HAS_TEAM:
        team_invitations = session.query(models.TeamInvitation).filter(
            models.TeamInvitation.email == username,
            models.TeamInvitation.status == "pending",
        ).all()

        for invitation in team_invitations:
            # Inviter names are user-controlled; escape them before embedding in HTML.
            inviter_name = html_escape(invitation.inviter.name or "")
            emailer.connect()
            emailer.send(
                subject="You're Invited to Join a Team on Compair",
                sender=EMAIL_USER,
                receivers=[invitation.email],
                html=(
                    f"<p>{inviter_name} has invited you to join their team on Compair!</p>\n"
                    f'<p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>\n'
                ),
            )
            invitation.status = "sent"
            session.commit()

    # --- Initial group membership ---
    if groups is not None:
        user.groups = load_groups_by_ids(groups)
    else:
        group = models.Group(
            name=username,
            datetime_created=now,
            group_image=None,
            category="Private",
            description=f"A private group for {username}",
            visibility="private"
        )
        admin = models.Administrator(user_id=user.user_id)
        session.add(admin)
        session.add(group)
        session.commit()
        group.admins.append(admin)
        user.groups = [group]

    # --- Referral tracking ---
    if referral_code:
        referrer = session.query(models.User).filter(models.User.referral_code == referral_code).first()
        if referrer and hasattr(referrer, "referral_credits"):
            max_credits = 3
            # NOTE(review): ``<=`` lets a referrer who already holds exactly
            # max_credits * 10 earn one more credit, and pending credits are
            # not counted toward the cap; confirm the intended limit.
            if referrer.referral_credits <= max_credits * 10:  # $10 per credit
                if hasattr(user, "referred_by"):
                    user.referred_by = referrer.user_id  # Store who referred them
                if hasattr(referrer, "pending_referral_credits"):
                    referrer.pending_referral_credits += 10  # Add pending credit
                session.commit()

    session.add(user)
    session.commit()

    # --- Analytics (best effort: never fail a signup because of tracking) ---
    tracker = analytics_provider if analytics_provider is not None else globals().get("analytics")
    if tracker is not None:
        try:
            tracker.track("user_signup", user.user_id)
        except Exception as exc:
            print(f"analytics track failed: {exc}")
    return user
670
+
671
+
672
def _activate_user_account(
    session: Session,
    user: models.User,
    *,
    send_group_invites: bool = True,
) -> None:
    """Activate a verified account and deliver its pending group invitations.

    The status becomes ``"trial"`` (with a 30-day ``trial_expiration_date``)
    when trials are supported, otherwise ``"active"``. The verification token
    is cleared and the session committed.

    Unless ``send_group_invites`` is false, every pending ``GroupInvitation``
    for ``user.username`` is processed as follows:
      * it is marked ``"sent"`` and given a token if it lacks one;
      * it gets a 7-day expiry and is committed;
      * it is emailed using ``GROUP_INVITATION_TEMPLATE``.

    Args:
        session: Open database session; committed once per step above.
        user: The account to activate.
        send_group_invites: Set to False to skip invitation emails.
    """
    from html import escape as html_escape

    now = datetime.now(timezone.utc)
    user.status = "trial" if HAS_TRIALS else "active"
    user.status_change_date = now
    if HAS_TRIALS:
        user.trial_expiration_date = now + timedelta(days=30)
    user.verification_token = None
    session.commit()

    if not send_group_invites:
        return

    pending_invitations = session.query(models.GroupInvitation).filter(
        models.GroupInvitation.email == user.username,
        models.GroupInvitation.status == "pending"
    ).all()
    for invitation in pending_invitations:
        invitation.status = "sent"
        if not invitation.token:
            invitation.token = secrets.token_urlsafe(32).lower()
        invitation.datetime_expiration = datetime.now(timezone.utc) + timedelta(days=7)
        session.commit()

        group = invitation.group
        inviter = invitation.inviter
        invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={invitation.token}&user_id={user.user_id}"
        emailer.connect()
        emailer.send(
            subject="You’re Invited to Join a Group on Compair",
            sender=EMAIL_USER,
            receivers=[user.username],
            # Render through the module's template helper, which tolerates None
            # values. Inviter and group names are user-controlled, so they are
            # HTML-escaped before insertion.
            html=_render_email(
                GROUP_INVITATION_TEMPLATE,
                inviter_name=html_escape(inviter.name or ""),
                group_name=html_escape(group.name or ""),
                invitation_link=invitation_link,
            ),
        )
715
+
716
+
717
@router.get("/load_session")
def load_session(auth_token: str | None = None) -> schema.Session | None:
    """Return the session identified by ``auth_token`` as a ``schema.Session``.

    With authentication disabled, ``auth_token`` is ignored and the singleton
    user's long-lived session is returned, created if necessary.

    Raises:
        HTTPException: 400 if no token is given, 404 if it is unknown, 401 if
            the session has expired.
    """
    settings = get_settings_dependency()
    with compair.Session() as session:
        if not settings.require_authentication:
            local_user = _ensure_single_user(session, settings)
            found = _ensure_single_user_session(session, local_user)
        else:
            if not auth_token:
                raise HTTPException(status_code=400, detail="auth_token is required when authentication is enabled.")
            found = session.query(models.Session).filter(models.Session.id == auth_token).first()
            if not found:
                raise HTTPException(status_code=404, detail="Session not found")
            expiry = found.datetime_valid_until
            # Naive timestamps from the database are interpreted as UTC.
            if expiry.tzinfo is None:
                expiry = expiry.replace(tzinfo=timezone.utc)
            if expiry < datetime.now(timezone.utc):
                raise HTTPException(status_code=401, detail="Invalid or expired session token")
        return schema.Session.model_validate(found, from_attributes=True)
737
+
738
+
739
@router.post("/update_user")
def update_user(
    name: str = Form(None),
    role: str = Form(None),
    group_ids: list[str] = Form(None),
    include_own_documents_in_feedback: str = Form(None),
    default_publish: str = Form(None),
    preferred_feedback_length: str = Form(None),
    hide_affiliations: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Update the current user's profile and preferences from form fields.

    Only supplied fields are changed.

    * Boolean preferences are true exactly when the submitted string is
      ``"true"`` (case-insensitive).
    * ``preferred_feedback_length`` is currently always stored as ``"Brief"``.
    * ``group_ids`` adds the user to those groups; existing memberships are
      kept and are not duplicated.
    """
    def _as_bool(raw: str) -> bool:
        return raw.lower() == "true"

    with compair.Session() as session:
        if name is not None:
            current_user.name = name
        if role is not None:
            current_user.role = role
        if group_ids is not None:
            # NOTE(review): any group id is accepted, including private groups
            # the user was never invited to; confirm this is intended.
            #
            # Compare memberships by primary key, not by instance:
            # load_groups_by_ids takes no session argument, so its groups are
            # presumably distinct instances from those on current_user.groups.
            # An identity-based ``in`` check would then re-append groups the
            # user already belongs to.
            member_ids = {g.group_id for g in current_user.groups}
            for group in load_groups_by_ids(group_ids):
                if group.group_id not in member_ids:
                    member_ids.add(group.group_id)
                    current_user.groups.append(group)
        if include_own_documents_in_feedback is not None:
            current_user.include_own_documents_in_feedback = _as_bool(include_own_documents_in_feedback)
        if default_publish is not None:
            current_user.default_publish = _as_bool(default_publish)
        if preferred_feedback_length is not None:
            # Lock to Brief for the time being
            current_user.preferred_feedback_length = 'Brief'
        if hide_affiliations is not None:
            current_user.hide_affiliations = _as_bool(hide_affiliations)
        # current_user was loaded in another (closed) session; add() re-attaches it here.
        session.add(current_user)
        session.commit()
772
+
773
@router.get("/update_session_duration")
def update_session_duration(
    user_session: schema.Session,
    new_valid_until: datetime,
) -> None:
    """Change the expiry time of an existing session.

    Args:
        user_session: Session whose ``id`` identifies the row to update.
        new_valid_until: New value for ``datetime_valid_until``.

    Raises:
        HTTPException: 404 if no session with that id exists.

    NOTE(review): this endpoint performs no authentication and accepts any
    expiry value; confirm that is intended.
    """
    with compair.Session() as session:
        stored = (
            session.query(models.Session)
            .filter(models.Session.id == user_session.id)
            .first()
        )
        if stored is None:
            raise HTTPException(status_code=404, detail="Session not found")
        # ``update({...})`` belongs to Query / UPDATE statements, not to the
        # default mapped-instance interface. A plain attribute assignment is
        # tracked by the session and flushed on commit.
        stored.datetime_valid_until = new_valid_until
        session.commit()
786
+
787
+
788
@router.get("/delete_user")
def delete_user(
    current_user: models.User = Depends(get_current_user)
):
    """Delete the authenticated user's account.

    Raises:
        HTTPException: 403 when ``require_authentication`` is disabled in
            settings (the local-user mode).
    """
    auth_required = get_settings_dependency().require_authentication
    if not auth_required:
        raise HTTPException(status_code=403, detail="Deleting the local user is not supported when authentication is disabled.")
    with compair.Session() as session:
        # NOTE(review): relies on models.User exposing an instance-level
        # delete(); a plain SQLAlchemy declarative model has none
        # (session.delete(obj) is the standard call) — confirm in models.py.
        current_user.delete()
        session.commit()
798
+
799
+
800
@router.get("/load_connections")
def load_connections(
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """Return one page of users who share at least one group with the caller.

    Args:
        page: 1-based page number.
        page_size: Users per page.
        filter_type: ``"recently_active"`` keeps users with a "create"
            activity in the last 7 days; ``"recently_compaired"`` keeps users
            whose documents received feedback in the last 7 days; anything
            else keeps everyone (ordered by name first).
        current_user: Authenticated user (injected).

    Returns:
        ``{"connections": [...], "total_count": int}``.
    """
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)

        # Groups the user belongs to, with their members eagerly loaded.
        groups = session.query(models.Group).options(joinedload(models.Group.users)).filter(
            models.Group.group_id.in_([g.group_id for g in current_user.groups])
        ).all()
        if not groups:
            return {"connections": [], "total_count": 0}

        # Everyone sharing a group with the user, excluding the user.
        connection_ids = {
            member.user_id
            for group in groups
            for member in group.users
            if member.user_id != current_user.user_id
        }

        q = session.query(models.User).filter(models.User.user_id.in_(connection_ids))

        # EXISTS-style filters (relationship.any) instead of JOINs: a JOIN
        # yields one row per matching activity/feedback, which duplicated
        # users in a page and inflated total_count.
        if filter_type == "recently_active":
            q = q.filter(models.User.activities.any(
                (models.Activity.action == "create")
                & (models.Activity.timestamp >= week_ago)
            ))
        elif filter_type == "recently_compaired":
            # NOTE(review): this only checks that some chunk of their
            # documents received feedback in the past week; it is not tied
            # to the current user's documents.
            q = q.filter(models.User.documents.any(
                models.Document.chunks.any(
                    models.Chunk.feedbacks.any(models.Feedback.timestamp >= week_ago)
                )
            ))
        else:
            q = q.order_by(models.User.name.asc())

        total_count = q.count()
        offset = (page - 1) * page_size
        connections = q.order_by(models.User.datetime_registered.desc()).offset(offset).limit(page_size).all()

        return {
            "connections": [
                {
                    "user_id": connection.user_id,
                    "username": connection.username,
                    "name": connection.name,
                    "datetime_registered": connection.datetime_registered,
                    "status": connection.status,
                    "profile_image": connection.profile_image,
                    "role": connection.role,
                }
                for connection in connections
            ],
            "total_count": total_count
        }
862
+
863
+
864
@router.get("/all_group_categories")
def all_group_categories(
    current_user: models.User = Depends(get_current_user)
):
    """Return all unique group categories, sorted ascending.

    Empty/``None`` categories and the internal ``'Compair'`` category are
    left out.
    """
    with compair.Session() as session:
        rows = (
            session.query(distinct(models.Group.category))
            .order_by(models.Group.category.asc())
            .all()
        )
    names = []
    for (category,) in rows:
        if not category or category == 'Compair':
            continue
        names.append(category)
    return {"categories": names}
878
+
879
+
880
@router.get("/load_groups")
def load_groups(
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    category: str | None = None,
    visibility: str | None = None,
    sort: str | None = None,
    query: str | None = None,
    own_groups_only: bool = False,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """Return one page of groups, filtered and sorted.

    Args:
        user_id: Acts only as a switch. When ``None``, every non-private
            group is eligible. Otherwise the *current user's* accessible
            groups are used (public groups, plus groups they belong to or
            hold a "sent" invitation for). The value itself is not used.
        page: 1-based page number.
        page_size: Groups per page.
        filter_type: ``"pending"`` (with *user_id*) keeps only groups the
            user has a "sent" invitation for; ``"joined"`` adds nothing.
        category: Exact category, or ``"all"``/``None`` for any.
        visibility: Exact visibility, or ``"all"``/``None`` for any.
        sort: ``"popular"``, ``"recently_updated"``, ``"recently_created"``,
            or name order by default.
        query: Case-insensitive substring to match against the group name.
        own_groups_only: With *user_id*, restrict to member/invited groups.
        current_user: Authenticated user (injected).

    Returns:
        ``{"groups": [...], "total_count": int}``.
    """
    with compair.Session() as session:
        invited_group_ids: set = set()

        # --- User-based group selection ---
        if user_id is None:
            q = session.query(models.Group).options(
                joinedload(models.Group.users),
                joinedload(models.Group.documents)
            ).filter(
                models.Group.visibility != 'private'
            )
        else:
            user = session.query(models.User).options(
                joinedload(models.User.groups).joinedload(models.Group.users),
                joinedload(models.User.groups).joinedload(models.Group.documents)
            ).filter(
                models.User.user_id == current_user.user_id
            ).first()
            user_group_ids = [g.group_id for g in user.groups if g.category != 'Compair']

            invitations = session.query(models.GroupInvitation).filter(
                models.GroupInvitation.email == user.username,
                models.GroupInvitation.status == "sent"
            ).all()
            invited_group_ids = {i.group_id for i in invitations}

            accessible_group_ids = set(user_group_ids) | invited_group_ids

            if own_groups_only:
                # Only groups the user is a member of or invited to
                q = session.query(models.Group).filter(
                    models.Group.group_id.in_(accessible_group_ids)
                )
            else:
                # Public groups, plus any group the user is a member of or invited to
                q = session.query(models.Group).filter(
                    (models.Group.visibility == "public") |
                    (models.Group.group_id.in_(accessible_group_ids))
                )

        # --- Filtering ---
        if category and category.lower() != "all":
            q = q.filter(models.Group.category == category)
        if visibility and visibility.lower() != "all":
            q = q.filter(models.Group.visibility == visibility)
        # filter_type == "joined" needs no extra filter: own_groups_only
        # already restricts to member/invited groups.
        if filter_type == "pending" and user_id:
            # Pending = groups with a "sent" invitation for this user; these
            # were already loaded above, so no second query is needed.
            q = q.filter(models.Group.group_id.in_(invited_group_ids))
        if query:
            q = q.filter(models.Group.name.ilike(f"%{query}%"))

        # --- Sorting ---
        if sort == "popular":
            q = q.outerjoin(models.Group.users).group_by(models.Group.group_id).order_by(func.count(models.User.user_id).desc())
        elif sort == "recently_updated":
            q = q.outerjoin(models.Group.documents).group_by(models.Group.group_id).order_by(func.max(models.Document.datetime_modified).desc())
        elif sort == "recently_created":
            q = q.order_by(models.Group.datetime_created.desc())
        else:
            q = q.order_by(models.Group.name.asc())

        # --- Paging ---
        total_count = q.count()
        offset = (page - 1) * page_size
        groups = q.offset(offset).limit(page_size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": group.document_count,
                "user_count": group.user_count,
                "first_three_user_profile_images": group.first_three_user_profile_images
            }
            for group in groups
        ]
        return {"groups": result, "total_count": total_count}
984
+
985
+
986
def load_groups_by_ids(group_ids: list[str]) -> list[models.Group]:
    """Fetch the ``models.Group`` rows whose ids appear in *group_ids*.

    The session used here is closed on return, so the returned ORM objects
    are detached. Their loaded column attributes are readable, but
    relationships that were not loaded presumably are not. Callers should
    attach the objects to their own session (or re-query) before traversing
    relationships.

    Args:
        group_ids: Group ids to look up; unknown ids are ignored.

    Returns:
        The matching groups (possibly empty), in database order.
    """
    if not group_ids:
        # Nothing to match; skip the round trip.
        return []
    with compair.Session() as session:
        return session.query(models.Group).filter(
            models.Group.group_id.in_(group_ids)
        ).all()
992
+
993
+
994
@router.get("/load_group")
def load_group(
    name: str | None = None,
    group_id: str | None = None
) -> schema.Group | None:
    """Look up a single group by id, or failing that by name.

    *group_id* takes precedence. *name* is compared with SQLAlchemy's
    backend-specific ``match`` operator. Returns ``None`` when neither
    argument is given or nothing matches.
    """
    # NOTE(review): no authentication dependency — private groups can be
    # read by anyone who knows the id/name; confirm this is intended.
    if name is None and group_id is None:
        return None
    with compair.Session() as session:
        if group_id is not None:
            criterion = models.Group.group_id == group_id
        else:
            criterion = models.Group.name.match(name)
        row = session.execute(select(models.Group).filter(criterion)).fetchone()
        return None if row is None else row[0]
1013
+
1014
+
1015
def notify_group_admins(
    group: models.Group,
    user_id: str
):
    """Email the admins of *group* that *user_id* has asked to join it.

    Nothing is sent (only a console message is printed) when the group has
    no admins or the requesting user cannot be found.

    Args:
        group: Group being requested; its ``admins`` relationship is read.
        user_id: Id of the requesting user.
    """
    from html import escape as _escape_html

    with compair.Session() as session:
        admin_emails = [admin.user.username for admin in group.admins]
        if not admin_emails:
            print("No admins found for group:", group.name)
            return
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if user is None:
            print("Join request notification skipped; user not found:", user_id)
            return
        emailer.connect()
        print("Admin emails:", admin_emails)
        # Group names and usernames are user-supplied (create_group takes the
        # name straight from a form), so escape them before putting them
        # into HTML to prevent markup injection into admins' inboxes.
        body = (
            GROUP_JOIN_TEMPLATE
            .replace("{{ user_name }}", _escape_html(user.username))
            .replace("{{ group_name }}", _escape_html(group.name))
            .replace("{{ admin_panel_url }}", f"http://{WEB_URL}/admin/groups")
        )
        emailer.send(
            subject="Group Join Request",
            sender=EMAIL_USER,
            receivers=admin_emails,
            html=body,
        )
1040
+
1041
@router.post("/join_group")
def join_group(
    group_id: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """Join, or request to join, *group_id* as the authenticated user.

    HTTP wrapper that delegates to ``join_group_direct`` and returns its
    result unchanged.
    """
    requester_id = current_user.user_id
    outcome = join_group_direct(user_id=requester_id, group_id=group_id)
    return outcome
1050
+
1051
def join_group_direct(
    user_id: str,
    group_id: str
):
    """Add *user_id* to *group_id*, or file a join request, based on visibility.

    Behaviour by visibility:

    * ``public``: the user joins immediately, and any "sent" invitations for
      them are marked accepted.
    * ``internal``: if a "sent" invitation exists, it is accepted and the
      user joins immediately. Otherwise a ``JoinRequest`` is recorded (at
      most once) and the group admins are emailed.
    * ``private``: always rejected.

    Returns:
        ``{"message": ...}`` describing the outcome, or ``None`` for an
        unrecognised visibility value.

    Raises:
        HTTPException: 404 if the group does not exist; 403 for private groups.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        if group.visibility == "private":
            raise HTTPException(status_code=403, detail="Cannot join private group without an invite")
        if group.visibility not in ("public", "internal"):
            return None

        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        # Any outstanding invitations to this group are accepted by joining.
        invitations = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username,
            models.GroupInvitation.status == "sent"
        ).all()
        for invitation in invitations:
            invitation.status = "accepted"

        if group.visibility == "public" or invitations:
            # Appending an existing member would insert a duplicate
            # association row, so only add the user if they are not a member.
            already_member = any(member.user_id == user_id for member in group.users)
            if not already_member:
                group.users.append(user)
            session.commit()
            if not already_member:
                log_activity(
                    session=session,
                    user_id=user_id,
                    group_id=group.group_id,
                    action="join",
                    object_id=group.group_id,
                    object_name=group.name,
                    object_type="group"
                )
            return {"message": "Joined group successfully"}

        # Internal group without an invitation: record a join request once.
        existing_request = session.query(models.JoinRequest).filter(
            models.JoinRequest.user_id == user_id,
            models.JoinRequest.group_id == group_id
        ).first()
        if not existing_request:
            join_request = models.JoinRequest(
                user_id=user_id,
                group_id=group_id,
                datetime_requested=datetime.now(timezone.utc)
            )
            session.add(join_request)
            session.commit()
            notify_group_admins(group, user_id)
        return {"message": "Join request sent to group admins"}
1112
+
1113
+
1114
@router.post("/create_group")
async def create_group(
    name: str = Form(...),
    category: str = Form(None),
    description: str = Form(None),
    visibility: str = Form("public"),
    file: UploadFile = File(None),  # Allow optional file upload
    current_user: models.User = Depends(get_current_user)
):
    """Create a group administered by, and containing, the authenticated user.

    Args:
        name: Group name.
        category: Must be one of the existing categories returned by
            ``all_group_categories``; anything else is stored as ``"Other"``.
        description: Free-text description.
        visibility: ``"public"``, ``"internal"`` or ``"private"``. Internal
            groups require an active team plan.
        file: Optional group image, uploaded after the group is created.
        current_user: Authenticated user (injected).

    Returns:
        The new group's ``group_id``, ``name``, ``visibility`` and ``category``.

    Raises:
        HTTPException: 400 for an unknown visibility; 403 when an internal
            group is requested without an active team plan.
    """
    # join_group_direct only handles these three values; a group with any
    # other visibility could never be joined.
    if visibility not in ("public", "internal", "private"):
        raise HTTPException(
            status_code=400,
            detail="Visibility must be one of: public, internal, private"
        )

    with compair.Session() as session:
        if category not in all_group_categories()['categories']:
            category = "Other"  # Default to "Other" if category is not valid

        # Limit internal group creation to active, team plans
        if visibility == 'internal' and not (current_user.status == 'active' and _user_plan(current_user) == 'team'):
            raise HTTPException(
                status_code=403,
                detail="Internal groups can only be created by users with an active team plan"
            )

        created_group = models.Group(
            name=name,
            group_image=None,
            category=category,
            description=description,
            visibility=visibility,
            # Timezone-aware UTC, like the other timestamps written in this module.
            datetime_created=datetime.now(timezone.utc),
        )

        # Reuse the user's Administrator record, creating one if needed.
        q = select(models.Administrator).filter(
            models.Administrator.user_id == current_user.user_id
        )
        admin_row = session.execute(q).fetchone()
        if admin_row is None:
            admin = models.Administrator(user_id=current_user.user_id)
            session.add(admin)
            session.commit()
        else:
            admin = admin_row[0]
        created_group.admins.append(admin)
        session.add(created_group)
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=created_group.group_id,
            action="create",
            object_id=created_group.group_id,
            object_name=created_group.name,
            object_type="group"
        )

        # The creator is also a member of the new group.
        admin.user.groups.append(created_group)
        session.add(admin)
        session.commit()

        if file is not None:
            await upload_group_image(
                group_id=created_group.group_id,
                upload_type='group',
                file=file
            )
        return {
            "group_id": created_group.group_id,
            "name": created_group.name,
            "visibility": created_group.visibility,
            "category": created_group.category,
        }
1191
+
1192
+
1193
@router.get("/load_documents")
def load_doc(
    group_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    own_documents_only: bool = True,
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, Any] | None:
    """Return one page of documents visible to the authenticated user.

    Args:
        group_id: Restrict to documents in this group.
        page: 1-based page number.
        page_size: Documents per page.
        filter_type:
            * ``"unpublished"`` (with *own_documents_only*): only unpublished
              documents.
            * ``"recently_updated"``: modified, or given a note, in the last
              7 days.
            * ``"recently_compaired"``: received feedback in the last 7 days.
            * Unless ``"unpublished"`` applies, only published documents or
              the user's own are returned.
        own_documents_only: Only the user's documents. Otherwise, documents
            in public groups or in groups the user belongs to.
        current_user: Authenticated user (injected).

    Returns:
        ``{"documents": [...], "total_count": int}``. Documents are ordered
        by last modification, then creation, newest first. A page past the
        end yields an empty list with the real ``total_count``.
    """
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        user_group_ids = {g.group_id for g in current_user.groups}

        q = session.query(models.Document)

        # Related-row conditions use EXISTS (relationship.any) rather than
        # JOINs. A JOIN yields one row per matching group/note/feedback,
        # which inflated total_count and shrank pages after de-duplication.
        if own_documents_only:
            q = q.filter(models.Document.user_id == current_user.user_id)
            if group_id is not None:
                q = q.filter(models.Document.groups.any(models.Group.group_id == group_id))
        else:
            visible_group = (
                (models.Group.visibility == "public")
                | models.Group.group_id.in_(user_group_ids)
            )
            if group_id:
                # Same group row must match the id AND be visible.
                visible_group = (models.Group.group_id == group_id) & visible_group
            q = q.filter(models.Document.groups.any(visible_group)).options(
                joinedload(models.Document.groups),
                joinedload(models.Document.user).joinedload(models.User.groups)
            )

        # --- Filter logic: publishing ---
        if filter_type == "unpublished" and own_documents_only:
            q = q.filter(models.Document.is_published == False)  # noqa: E712 (SQL expression)
        else:
            q = q.filter(
                or_(
                    models.Document.is_published == True,  # noqa: E712 (SQL expression)
                    models.Document.user_id == current_user.user_id
                )
            )

        # --- Filter logic: other ---
        if filter_type == "recently_updated":
            q = q.filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Document.notes.any(models.Note.datetime_created >= week_ago)
                )
            )
        elif filter_type == "recently_compaired":
            q = q.filter(
                models.Document.chunks.any(
                    models.Chunk.feedbacks.any(models.Feedback.timestamp >= week_ago)
                )
            )

        # Count first instead of loading every matching document just to
        # test for emptiness.
        total_count = q.count()
        if total_count == 0:
            return {"documents": [], "total_count": 0}

        offset = (page - 1) * page_size
        documents = (
            q.order_by(
                models.Document.datetime_modified.desc(),
                models.Document.datetime_created.desc()
            )
            .offset(offset)
            .limit(page_size)
            .all()
        )
        return {
            "documents": [schema.Document.model_validate(d) for d in documents],
            "total_count": total_count
        }
1286
+
1287
+
1288
@router.get("/load_group_users")
def load_group_users(
    group_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return one page of the members of *group_id*.

    Args:
        group_id: Group whose members are listed.
        page: 1-based page number.
        page_size: Users per page.
        filter_type: ``"recently_active"`` orders by latest status change;
            ``"recently_joined"`` by registration date (both newest first);
            otherwise by name.
        current_user: Authenticated user (injected).

    Returns:
        ``{"users": [...], "total_count": int}``.

    Raises:
        HTTPException: 404 if the group does not exist; 403 if it is not
            public and the caller is not a member.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        # Compare ids: current_user was loaded in a different session than
        # group.users, so an identity-based `in` check is not reliable.
        is_member = any(member.user_id == current_user.user_id for member in group.users)
        if not is_member and group.visibility != 'public':
            raise HTTPException(status_code=403, detail="User does not belong to the group")

        users_query = session.query(models.User).join(models.User.groups).filter(models.Group.group_id == group_id)

        # --- Ordering ---
        if filter_type == "recently_active":
            users_query = users_query.order_by(models.User.status_change_date.desc())
        elif filter_type == "recently_joined":
            users_query = users_query.order_by(models.User.datetime_registered.desc())
        else:
            users_query = users_query.order_by(models.User.name.asc())

        total_count = users_query.count()
        offset = (page - 1) * page_size
        users = users_query.offset(offset).limit(page_size).all()

        return {
            "users": [
                {
                    "user_id": u.user_id,
                    "username": u.username,
                    "name": u.name,
                    "datetime_registered": u.datetime_registered,
                    "status": u.status,
                    "profile_image": u.profile_image,
                    "role": u.role,
                }
                for u in users
            ],
            "total_count": total_count
        }
1340
+
1341
+
1342
@router.get("/load_document")
def load_doc(
    title: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Find one of the caller's own documents by title.

    *title* is compared with SQLAlchemy's backend-specific ``match``
    operator. The document's groups and its owner's groups are eagerly
    loaded. Returns ``None`` when nothing matches.
    """
    stmt = (
        select(models.Document)
        .filter(models.Document.user_id == current_user.user_id)
        .filter(models.Document.title.match(title))
        .options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups),
        )
    )
    with compair.Session() as session:
        row = session.execute(stmt).unique().fetchone()
        if row is None:
            return None
        return row[0]
1360
+
1361
+
1362
@router.get("/load_document_by_id")
def load_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Return a document by id if the caller may view it.

    Access is granted to the document's owner (``user_id``), its author
    (``author_id``), and anyone sharing at least one group with it.

    Returns:
        The document, or ``None`` if the id does not exist.

    Raises:
        HTTPException: 403 if the caller has no access.
    """
    with compair.Session() as session:
        q = select(models.Document).filter(
            models.Document.document_id == document_id
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups)
        )
        row = session.execute(q).unique().fetchone()
        if row is None:
            return None
        doc = row[0]

        # The owner (user_id) may differ from author_id (create_doc accepts
        # an explicit author), so check both.
        is_owner_or_author = current_user.user_id in (doc.user_id, doc.author_id)
        doc_group_ids = {g.group_id for g in doc.groups}
        user_group_ids = {g.group_id for g in current_user.groups}
        if not is_owner_or_author and not (doc_group_ids & user_group_ids):
            raise HTTPException(status_code=403, detail="Not authorized to view this document")
        # NOTE(review): group members can read unpublished documents here,
        # whereas /load_documents lists only published ones (or the user's
        # own) — confirm this is intended.
        return doc
1383
+
1384
+
1385
@router.post("/update_doc")
def update_doc(
    doc_id: str = Form(...),
    author_id: str = Form(None),
    title: str = Form(None),
    datetime_created: datetime = Form(None),
    group_ids: list[str] = Form(None),
    image_url: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Update metadata of one of the caller's documents.

    Only fields that were supplied (not ``None``) are changed. *group_ids*
    replaces the document's full group list; unknown ids are ignored.

    Raises:
        HTTPException: 404 if the document does not exist or is not owned by
            the caller (previously this silently did nothing).
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        ).first()
        if doc is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to edit it.")

        if author_id is not None:
            doc.author_id = author_id
        if title is not None:
            doc.title = title
        if datetime_created is not None:
            doc.datetime_created = datetime_created
        if group_ids is not None:
            # Load the groups in this session so the new associations are
            # attached to the same session as the document.
            # NOTE(review): membership/visibility of the target groups is not
            # checked — confirm users may file documents into any group.
            doc.groups = session.query(models.Group).filter(
                models.Group.group_id.in_(group_ids)
            ).all()
        if image_url is not None:
            doc.image_url = image_url
        session.commit()
1421
+
1422
+
1423
def _document_limit_for(user: models.User) -> int | None:
    """Return the max number of documents *user* may own, or ``None`` for no limit.

    Cloud plans are checked first: an active team member gets the team's
    pooled limit, and an active individual plan gets 100. Otherwise the
    ``COMPAIR_CORE_DOCUMENT_LIMIT`` environment variable applies when it
    holds an integer.
    """
    team = _user_team(user)
    if IS_CLOUD and HAS_TEAM and team and user.status == "active":
        return team.total_documents_limit  # type: ignore[union-attr]
    if IS_CLOUD and _user_plan(user) == "individual" and user.status == "active":
        return 100
    raw_core_limit = os.getenv("COMPAIR_CORE_DOCUMENT_LIMIT")
    if raw_core_limit:
        try:
            return int(raw_core_limit)
        except ValueError:
            return None
    return None


@router.post("/create_doc")
def create_doc(
    authorid: str = Form(None),
    document_title: str = Form(None),
    document_type: str = Form(None),
    document_content: str = Form(""),
    groups: str = Form(None),  # TODO: Fix how these get submitted; current comma-separated list string
    is_published: bool = Form(False),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
):
    """Create a document for the authenticated user.

    Args:
        authorid: Author id; defaults to the caller.
        document_title: Title.
        document_type: Stored as ``doc_type``.
        document_content: Initial content.
        groups: Comma-separated group ids. When empty, the group named after
            the user's username is used.
        is_published: Attempt to publish right after creation.
        current_user: Authenticated user (injected).
        analytics: Analytics sink (injected); failures are only printed.

    Returns:
        ``{"document_id": ...}``.

    Raises:
        HTTPException: 403 when the document limit is reached; 404 when none
            of the given groups (or the default group) exist.
    """
    with compair.Session() as session:
        current_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        now = datetime.now(timezone.utc)

        # Mark the account suspended once its trial has expired.
        trial_expiration = _trial_expiration(current_user)
        if HAS_TRIALS and trial_expiration and current_user.status == "trial" and trial_expiration < now:
            current_user.status = "suspended"
            current_user.status_change_date = now
            session.commit()

        # Enforce document limits (cloud plans) – core runs are unrestricted unless explicitly configured
        document_limit = _document_limit_for(current_user)
        document_count = session.query(models.Document).filter(models.Document.user_id == current_user.user_id).count()
        if document_limit is not None and document_count >= document_limit:
            if IS_CLOUD:
                detail_msg = (
                    "Document limit reached. Individual plan users can have 100, team plans have 100 times "
                    "the number of users (pooled); other plans can have 10"
                )
            else:
                detail_msg = (
                    f"Document limit of {document_limit} reached. Adjust COMPAIR_CORE_DOCUMENT_LIMIT to raise "
                    "or unset it to remove limits in core deployments."
                )
            raise HTTPException(status_code=403, detail=detail_msg)

        document = models.Document(
            user_id=current_user.user_id,
            author_id=authorid or current_user.user_id,
            title=document_title,
            content=document_content,
            doc_type=document_type,
            datetime_created=now,
            datetime_modified=now
        )

        target_group_ids = [gid.strip() for gid in (groups or "").split(',') if gid.strip()]
        if target_group_ids:
            q = select(models.Group).filter(models.Group.group_id.in_(target_group_ids))
            resolved_groups = session.execute(q).scalars().all()
            if not resolved_groups:
                raise HTTPException(status_code=404, detail="No matching groups found for provided IDs.")
            document.groups = resolved_groups
        else:
            q = select(models.Group).filter(models.Group.name == current_user.username)
            default_group = session.execute(q).scalars().first()
            if default_group is None:
                raise HTTPException(status_code=404, detail="Default group not found for user.")
            document.groups = [default_group]

        primary_group = document.groups[0]

        # (The full document content is intentionally no longer printed to stdout.)
        session.add(document)
        session.commit()

        if is_published:
            # Attempt to publish doc, pending user status check
            publish_doc(
                doc_id=document.document_id,
                is_published=True,
                current_user=current_user
            )

        log_activity(
            session=session,
            user_id=document.author_id,
            group_id=primary_group.group_id,
            action="create",
            object_id=document.document_id,
            object_name=document.title,
            object_type="document"
        )
        try:
            analytics.track("document_created", document.user_id)
        except Exception as exc:
            # Analytics is best-effort; never fail document creation over it.
            print(f"analytics track failed: {exc}")
        return {"document_id": document.document_id}
1533
+
1534
+
1535
@router.get("/publish_doc")
def publish_doc(
    doc_id: str,
    is_published: bool,
    current_user: models.User = Depends(get_current_user)
):
    """Set the published flag of one of the caller's documents.

    Args:
        doc_id: Document to change.
        is_published: Desired state. Nothing is written if it already holds.
        current_user: Authenticated user.

    Raises:
        HTTPException: 403 if a suspended user tries to publish; 404 if the
            document does not exist or is not owned by the caller.
    """
    with compair.Session() as session:
        # Only publishing is gated on subscription status; a suspended user
        # may still unpublish their own documents.
        if is_published and current_user.status == "suspended":
            raise HTTPException(status_code=403, detail="Your subscription has expired. Renew to publish.")

        doc = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        ).first()
        if doc is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to publish it.")
        if doc.is_published != is_published:
            doc.is_published = is_published
            session.commit()
1555
+
1556
+
1557
@router.get("/delete_docs")
def delete_docs(
    doc_ids: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete several of the caller's documents, logging one activity each.

    Args:
        doc_ids: Comma-separated document ids. Surrounding whitespace and
            empty entries are ignored; ids not owned by the caller are
            skipped.
        current_user: Authenticated user.
    """
    ids = [part.strip() for part in doc_ids.split(',') if part.strip()]
    if not ids:
        return
    with compair.Session() as session:
        documents = session.query(models.Document).filter(
            models.Document.document_id.in_(ids),
            models.Document.user_id == current_user.user_id
        )

        for document in documents:
            # A document may have no groups; don't assume groups[0] exists.
            # NOTE(review): confirm log_activity accepts group_id=None.
            group_id = document.groups[0].group_id if document.groups else None
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=document.document_id,
                object_name=document.title,
                object_type="document"
            )

        documents.delete()
        session.commit()
1586
+
1587
+
1588
@router.get("/delete_doc")
def delete_doc(
    doc_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete one of the caller's documents and log the deletion.

    Raises:
        HTTPException: 404 if the document does not exist or is not owned by
            the caller.
    """
    with compair.Session() as session:
        query = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        )
        # Fetch once (the old code ran .first() and then indexed the query
        # twice, issuing three SELECTs).
        document = query.first()
        if document is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to delete it.")

        # Capture what the activity log needs before the row disappears.
        # A document may have no groups; don't assume groups[0] exists.
        # NOTE(review): confirm log_activity accepts group_id=None.
        group_id = document.groups[0].group_id if document.groups else None
        doc_name = document.title

        query.delete()
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=group_id,
            action="delete",
            object_id=doc_id,
            object_name=doc_name,
            object_type="document"
        )
1617
+
1618
+
1619
@router.post("/process_doc")
async def process_doc(
    doc_id: str = Form(...),
    doc_text: str = Form(...),
    generate_feedback: bool = Form(True),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
) -> Mapping[str, str | None]:
    """Queue background processing of a document's new text.

    Feedback generation is forced off for suspended users, who may still
    edit. A "feedback_requested" analytics event is tracked (best-effort)
    only when feedback will be generated.

    Returns:
        ``{"task_id": ...}``: the dispatched task's ``id``, or ``None`` if
        the dispatch result has no ``id``.

    Raises:
        HTTPException: 404 if the document does not exist; 403 if the caller
            is not its author.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == doc_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")
        # NOTE(review): this checks author_id, while update/publish/delete
        # check ownership via user_id. An owner whose document names a
        # different author cannot process it — confirm intended.
        if doc.author_id != current_user.user_id:
            raise HTTPException(status_code=403, detail="Only the author can edit this document")

        wants_feedback = generate_feedback and current_user.status != "suspended"

        dispatched = _dispatch_process_document_task(
            user_id=current_user.user_id,
            doc_id=doc_id,
            doc_text=doc_text,
            generate_feedback=wants_feedback,
        )

        if wants_feedback:
            try:
                analytics.track("feedback_requested", current_user.user_id)
            except Exception as exc:
                print(f"analytics track failed: {exc}")
        return {"task_id": getattr(dispatched, "id", None)}
1652
+
1653
+
1654
@router.post("/upload/ocr-file")
async def upload_ocr_file(
    file: UploadFile = File(...),
    document_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Submit an uploaded file to the configured OCR provider.

    Returns:
        ``{"task_id": ...}`` as returned by ``ocr.submit``.

    Raises:
        HTTPException: 501 if OCR is disabled or the provider raises
            ``NotImplementedError``. An ``HTTPException`` raised by the
            provider is passed through unchanged. Any other error becomes a
            500.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    file_bytes = await file.read()
    try:
        task_id = ocr.submit(
            user_id=current_user.user_id,
            filename=file.filename or "upload",
            data=file_bytes,
            document_id=document_id,
        )
    except HTTPException:
        # Keep the provider's own status code instead of rewrapping as 500.
        raise
    except NotImplementedError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"task_id": task_id}
1679
+
1680
+
1681
@router.get("/ocr-file-result/{task_id}")
def get_ocr_file_result(
    task_id: str,
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Report what the OCR provider says about ``task_id``.

    A dict status from ``ocr.status`` is merged into the response. Any other
    value is returned under ``"status"``. Both shapes carry ``task_id``.
    Responds 501 when OCR is disabled or unimplemented, and 500 on other
    provider errors.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    try:
        outcome = ocr.status(task_id)
    except NotImplementedError as err:
        raise HTTPException(status_code=501, detail=str(err)) from err
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err)) from err

    extra = outcome if isinstance(outcome, dict) else {"status": outcome}
    return {"task_id": task_id, **extra}
1704
+
1705
+
1706
@router.get("/status/{task_id}")
async def get_process_status(
    task_id: str
):
    """Report the Celery state of a background task.

    ``result`` depends on the state:
    - ``SUCCESS``: the task's return value.
    - ``FAILURE`` or ``REVOKED``: a failure message.
    - Any other state (``PENDING``, ``RECEIVED``, ``STARTED``, ``RETRY``, custom
      progress states): a "still processing" message.
    """
    task_result = AsyncResult(task_id)
    # Read the state once. Each ``.status`` access may query the result
    # backend, and the state can change between reads.
    state = task_result.status
    if state == "SUCCESS":
        result = task_result.result
    elif state in ("FAILURE", "REVOKED"):
        result = "Task failed."
    else:
        # Non-terminal states must not be reported as failures.
        result = "Task is still processing."
    return {"task_id": task_id, "status": state, "result": result}
1720
+
1721
+
1722
@router.get("/trial_status")
def get_trial_status(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return the trial status for the authenticated user.

    Response keys:
        status: the user's account status.
        days_left: whole days remaining until ``trial_expiration_date``.
        show_banner: true when between 1 and 7 days remain.
        checkout_url: a billing checkout link. It is only attempted while the
            banner is shown and billing is enabled; otherwise it is ``None``.

    Raises:
        HTTPException(404): the user has no trial expiration date.
    """
    require_feature(HAS_TRIALS, "Trial management")
    with compair.Session() as session:
        trial_expiration = getattr(current_user, "trial_expiration_date", None)
        if trial_expiration is None:
            raise HTTPException(status_code=404, detail="Trial expiration date not found.")
        # Some database backends return naive datetimes. Treat them as UTC so
        # subtracting an aware "now" cannot raise TypeError.
        if trial_expiration.tzinfo is None:
            trial_expiration = trial_expiration.replace(tzinfo=timezone.utc)
        days_left = (trial_expiration - datetime.now(timezone.utc)).days
        show_banner = 0 < days_left <= 7

        checkout_url = None
        if show_banner and settings.billing_enabled:
            require_feature(HAS_BILLING, "Billing integration")
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if db_user:
                if not db_user.stripe_customer_id:
                    try:
                        customer_id = billing.ensure_customer(
                            user_email=db_user.username,
                            user_id=db_user.user_id,
                        )
                    except NotImplementedError:
                        customer_id = None
                    except Exception as exc:
                        # The checkout link is best-effort. A billing error must
                        # not break the trial-status endpoint.
                        print(f"Error ensuring billing customer: {exc}")
                        customer_id = None
                    else:
                        db_user.stripe_customer_id = customer_id
                        session.commit()
                customer_id = db_user.stripe_customer_id
                if customer_id:
                    try:
                        session_info = billing.create_checkout_session(
                            customer_id=customer_id,
                            price_id=plan_to_id.get("individual_monthly"),
                            qty=1,
                            success_url=settings.stripe_success_url,
                            cancel_url=settings.stripe_cancel_url,
                            metadata={"plan": "individual_monthly", "user_id": db_user.user_id},
                        )
                        # Providers may return an object with ``.url`` or a mapping.
                        checkout_url = getattr(session_info, "url", None)
                        if checkout_url is None and hasattr(session_info, "get"):
                            checkout_url = session_info.get("url")
                    except NotImplementedError:
                        checkout_url = None
                    except Exception as exc:
                        print(f"Error generating Checkout URL: {exc}")

    return {
        "status": current_user.status,
        "days_left": days_left,
        "show_banner": show_banner,
        "checkout_url": checkout_url,
    }
1779
+
1780
+
1781
@router.get("/load_chunks")
def load_chunks(
    document_id: str,
) -> list[schema.Chunk]:
    """Return every chunk that belongs to ``document_id``.

    The query runs with ``.all()`` while the session is still open.
    Returning the lazy ``Query`` object would defer execution until FastAPI
    serialises the response, which happens after the session has closed.
    A bare ``Query`` is also not a list for response validation.
    """
    with compair.Session() as session:
        return (
            session.query(models.Chunk)
            .filter(models.Chunk.document_id == document_id)
            .all()
        )
1790
+
1791
+
1792
@router.get("/load_feedback")
def load_feedback(
    chunk_id: str,
) -> schema.Feedback | None:
    """Return the first feedback entry attached to ``chunk_id``, or ``None``.

    The payload includes the user's rating (``user_feedback``) and the
    ``is_hidden`` flag.
    """
    with compair.Session() as session:
        # One LIMIT 1 query, instead of loading every row with .all() and
        # then issuing a second query for feedback[0].
        f = (
            session.query(models.Feedback)
            .filter(models.Feedback.source_chunk_id == chunk_id)
            .first()
        )
        if f is None:
            return None
        return {
            "feedback_id": f.feedback_id,
            "source_chunk_id": f.source_chunk_id,
            "feedback": f.feedback,
            "user_feedback": f.user_feedback,
            "is_hidden": f.is_hidden,
        }
1811
+
1812
+
1813
@router.post("/feedback/{feedback_id}/hide")
def hide_feedback(
    feedback_id: str,
    is_hidden: bool = Form(False)
):
    """Set the ``is_hidden`` flag on one feedback entry.

    Raises:
        HTTPException(404): no feedback has ``feedback_id``.
    """
    with compair.Session() as session:
        entry = (
            session.query(models.Feedback)
            .filter(models.Feedback.feedback_id == feedback_id)
            .first()
        )
        if not entry:
            raise HTTPException(status_code=404, detail="Feedback not found")
        entry.is_hidden = is_hidden
        session.commit()
    verb = "hidden" if is_hidden else "unhidden"
    return {"message": f"Feedback {verb} successfully", "feedback_id": feedback_id}
1825
+
1826
+
1827
+
1828
@router.post("/feedback/{feedback_id}/rate")
def rate_feedback(
    feedback_id: str,
    user_feedback: str = Body(..., embed=True)  # expects "positive" or "negative"
):
    """Store the user's rating label on a feedback entry.

    Raises:
        HTTPException(400): the label is not an accepted value.
        HTTPException(404): no feedback has ``feedback_id``.
    """
    accepted_labels = ("positive", "negative", None)
    if user_feedback not in accepted_labels:
        raise HTTPException(status_code=400, detail="Invalid feedback label")
    with compair.Session() as session:
        target = (
            session.query(models.Feedback)
            .filter(models.Feedback.feedback_id == feedback_id)
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Feedback not found")
        target.user_feedback = user_feedback
        session.commit()
    return {
        "message": "Feedback rated successfully",
        "feedback_id": feedback_id,
        "user_feedback": user_feedback,
    }
1842
+
1843
+
1844
@router.get("/documents/{document_id}/feedback")
def list_document_feedback(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """List the feedback on a document's chunks, newest first.

    Access is granted to any one of:
    - the document's ``user_id``,
    - a member of a group the document belongs to,
    - anyone, when the document is published.

    Raises:
        HTTPException(404): the document does not exist.
        HTTPException(403): the caller may not read it.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        mine = {g.group_id for g in current_user.groups}
        shared_with = {g.group_id for g in doc.groups}
        permitted = (
            doc.user_id == current_user.user_id
            or bool(shared_with & mine)
            or doc.is_published
        )
        if not permitted:
            raise HTTPException(status_code=403, detail="Not authorized to access this document")

        # Feedback hangs off chunks, so join through them to filter by document.
        rows = (
            session.query(models.Feedback)
            .join(models.Chunk, models.Feedback.source_chunk_id == models.Chunk.chunk_id)
            .filter(models.Chunk.document_id == document_id)
            .order_by(models.Feedback.timestamp.desc())
            .all()
        )
        entries = [
            {
                "feedback_id": row.feedback_id,
                "chunk_id": row.source_chunk_id,
                "feedback": row.feedback,
                "user_feedback": row.user_feedback,
                "timestamp": row.timestamp,
            }
            for row in rows
        ]
        return {"document_id": document_id, "count": len(entries), "feedback": entries}
1880
+
1881
+
1882
@router.get("/load_references")
def load_references(
    chunk_id: str,
) -> list[schema.Reference]:
    """Return the references recorded for ``chunk_id``.

    Each reference embeds its target document and that document's user.
    The ``groups`` lists on both are placeholders, not real group
    memberships: each holds a single entry with an empty ``group_id`` and
    ``name``.
    """
    with compair.Session() as session:
        references = session.query(models.Reference).filter(
            models.Reference.source_chunk_id == chunk_id
        ).all()
        # Built inside the session because r.document and r.document.user are
        # relationships that may need to be loaded lazily.
        returned_references: list[schema.Reference] = [
            schema.Reference(
                reference_id=r.reference_id,
                source_chunk_id=r.source_chunk_id,
                reference_document_id=r.reference_document_id,
                document=schema.Document(
                    document_id=r.document.document_id,
                    user_id=r.document.user_id,
                    author_id=r.document.author_id,
                    title=r.document.title,
                    content=r.document.content,
                    doc_type=r.document.doc_type,
                    datetime_created=r.document.datetime_created,
                    datetime_modified=r.document.datetime_modified,
                    is_published=r.document.is_published,
                    groups=[{"group_id": "", "datetime_created": datetime.now(), "name": ""}],
                    user=schema.User(
                        user_id=r.document.user.user_id,
                        username=r.document.user.username,
                        name=r.document.user.name,
                        groups=[{"group_id": "", "datetime_created": datetime.now(), "name": ""}],
                        datetime_registered=r.document.user.datetime_registered,
                        status=r.document.user.status
                    )
                ),
                document_author=r.document.user.username
            ) for r in references
        ]
        return returned_references
1920
+
1921
@router.get("/verify-email")
def verify_email(token: str):
    """Activate the account that owns the email-verification ``token``.

    Raises:
        HTTPException(403): authentication (and so verification) is disabled.
        HTTPException(400): the token is unknown, has no expiry, or has expired.
    """
    settings = get_settings_dependency()
    if not settings.require_authentication:
        raise HTTPException(status_code=403, detail="Email verification is disabled when authentication is disabled.")
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.verification_token == token).first()
        # Check for a missing user before touching its attributes, so an
        # unknown token yields a 400 rather than an AttributeError (500).
        if not user:
            raise HTTPException(status_code=400, detail="Invalid or expired token")
        expiration = user.token_expiration
        # forgot/reset-password flows also write token_expiration (reset clears
        # it to None); comparing None with a datetime would raise TypeError.
        if expiration is None:
            raise HTTPException(status_code=400, detail="Invalid or expired token")
        # Treat naive datetimes as UTC so the comparison cannot raise TypeError.
        if expiration.tzinfo is None:
            expiration = expiration.replace(tzinfo=timezone.utc)
        if expiration < datetime.now(timezone.utc):
            raise HTTPException(status_code=400, detail="Token has expired")
        _activate_user_account(session, user, send_group_invites=True)
    return {"message": "Email verified successfully. Your free trial has started!"}
1938
+
1939
def is_valid_email(email):
    """Return True when ``email`` has the shape ``local@domain.tld``.

    This is a cheap shape check, not RFC 5322 validation. The whole string
    must be one ``@`` separating a non-empty local part from a domain that
    contains at least one dot, with no whitespace anywhere.

    ``re.fullmatch`` is used because ``re.match`` only anchors the start, so
    trailing junk such as ``"a@b.co extra"`` would otherwise pass.
    Non-string input returns False.
    """
    if not isinstance(email, str):
        return False
    return re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email) is not None
1941
+
1942
@router.post("/sign-up")
def sign_up(
    request: schema.SignUpRequest,
    analytics: Analytics = Depends(get_analytics),
) -> dict:
    """Create a new account from ``request``.

    When ``settings.require_email_verification`` is set, a verification link
    is emailed to the new user and activation is left to the verification
    flow. Otherwise the account is activated immediately with
    ``send_group_invites=False``.

    Raises:
        HTTPException(403): authentication, and therefore sign-up, is disabled.
        HTTPException(400): ``request.username`` is not a valid email address.
    """
    settings = get_settings_dependency()
    if not settings.require_authentication:
        raise HTTPException(status_code=403, detail="Sign-up is disabled when authentication is disabled.")
    if not is_valid_email(request.username):
        raise HTTPException(status_code=400, detail="Invalid email address")
    with compair.Session() as session:
        user = create_user(
            username=request.username,
            name=request.name,
            password=request.password,
            groups=request.groups,
            session=session,
            referral_code=request.referral_code,
        )
        if settings.require_email_verification:
            verification_link = f"http://{WEB_URL}/verify-email?token={user.verification_token}"
            emailer.connect()
            emailer.send(
                subject="Verify your email address",
                sender=EMAIL_USER,
                receivers=[user.username],
                html=_render_email(
                    ACCOUNT_VERIFY_TEMPLATE,
                    verification_link=verification_link,
                    user_name=user.name or user.username or "there",
                ),
            )
            return {"message": "Sign-up successful. Please check your email for verification."}
        _activate_user_account(session, user, send_group_invites=False)
        return {"message": "Sign-up successful. Your account is ready to use."}
1985
+
1986
@router.post("/forgot-password")
def forgot_password(request: schema.ForgotPasswordRequest) -> dict:
    """Email a password-reset link to ``request.email`` if that user exists.

    The same message is returned whether or not the account exists, so the
    endpoint does not reveal which emails are registered.

    Raises:
        HTTPException(403): authentication is disabled.
    """
    settings = get_settings_dependency()
    if not settings.require_authentication:
        raise HTTPException(status_code=403, detail="Password resets are disabled when authentication is disabled.")
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.username == request.email).first()
        if not user:
            return {"message": "If the email exists, a reset link will be sent."}

        # Reuses the verification-token generator. The token is stored
        # lower-cased, and the same value goes into the link. It is never
        # printed, because anyone reading the logs could reset the password.
        token, expiration = generate_verification_token()
        token = token.lower()
        user.reset_token = token
        user.token_expiration = expiration
        session.commit()

        reset_link = f"http://{WEB_URL}/reset-password?token={token}"
        emailer.connect()
        emailer.send(
            subject="Password Reset Request",
            sender=EMAIL_USER,
            receivers=[request.email],
            html=_render_email(
                PASSWORD_RESET_TEMPLATE,
                reset_link=reset_link,
                user_name=user.name or user.username or "",
            ),
        )
    return {"message": "If the email exists, a reset link will be sent."}
2024
+
2025
@router.post("/reset-password")
def reset_password(request: schema.ResetPasswordRequest) -> dict:
    """Set a new password for the user holding ``request.token``.

    On success the reset token and its expiry are cleared, so the link
    cannot be reused.

    Raises:
        HTTPException(403): authentication is disabled.
        HTTPException(400): the token is empty, unknown, has no expiry, or has
            expired.
    """
    settings = get_settings_dependency()
    if not settings.require_authentication:
        raise HTTPException(status_code=403, detail="Password resets are disabled when authentication is disabled.")
    # An empty or None token must never reach the query. In SQLAlchemy,
    # ``reset_token == None`` compiles to ``IS NULL``, which would match every
    # user without a pending reset.
    if not request.token:
        raise HTTPException(status_code=400, detail="Invalid or expired token")
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.reset_token == request.token).first()
        expiration = user.token_expiration if user else None
        # Treat naive datetimes as UTC so the comparison cannot raise TypeError.
        if expiration is not None and expiration.tzinfo is None:
            expiration = expiration.replace(tzinfo=timezone.utc)
        if expiration is None or expiration < datetime.now(timezone.utc):
            raise HTTPException(status_code=400, detail="Invalid or expired token")

        user.set_password(request.new_password)
        user.reset_token = None  # Invalidate the token
        user.token_expiration = None
        session.commit()

    return {"message": "Password has been reset successfully"}
2048
+
2049
@router.get("/admin/groups")
def get_admin_groups(
    current_user: models.User = Depends(get_current_user),
):
    """List the groups the current user administers.

    Returns an empty list when the user has no Administrator record.
    """
    exported_fields = (
        "group_id",
        "name",
        "visibility",
        "category",
        "description",
        "group_image",
        "datetime_created",
    )
    with compair.Session() as session:
        admin = session.query(models.Administrator).filter(models.Administrator.user_id == current_user.user_id).first()
        if not admin:
            return []  # Not an admin of any groups
        return [
            {field: getattr(group, field) for field in exported_fields}
            for group in admin.groups
        ]
2073
+
2074
@router.get("/admin/join_requests")
def get_join_requests(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """List the join requests recorded for ``group_id``.

    Only requests for groups the current user administers are returned, so
    non-admins get an empty list.
    """
    with compair.Session() as session:
        administered_by_caller = models.JoinRequest.group.has(
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        )
        pending = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.group_id == group_id)
            .filter(administered_by_caller)
            .all()
        )
        # Built inside the session: r.user is a relationship access.
        return [
            {
                "request_id": jr.request_id,
                "user_name": jr.user.name,
                "datetime_requested": jr.datetime_requested,
            }
            for jr in pending
        ]
2090
+
2091
@router.post("/admin/approve_request")
def approve_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Approve a join request and add the requester to the group.

    Only admins of the request's group can see the request; anyone else gets
    a 404. The request row is deleted, and a ``join`` activity is logged.

    Raises:
        HTTPException(404): no such request among the caller's groups.
    """
    with compair.Session() as session:
        request = session.query(models.JoinRequest).filter(
            models.JoinRequest.request_id == request_id
        ).filter(
            models.JoinRequest.group.has(models.Group.admins.any(models.Administrator.user_id == current_user.user_id))
        ).first()
        if not request:
            raise HTTPException(status_code=404, detail="Request not found")

        group: models.Group = request.group
        user: models.User = request.user
        # The requester may already be a member (e.g. via an accepted
        # invitation). Appending again would insert a duplicate association row.
        if user not in group.users:
            group.users.append(user)
        session.delete(request)
        session.commit()

        log_activity(
            session=session,
            user_id=user.user_id,
            group_id=group.group_id,
            action="join",
            object_id=group.group_id,
            object_name=group.name,
            object_type="group"
        )

    return {"message": "Request approved successfully"}
2123
+
2124
@router.post("/admin/reject_request")
def reject_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Reject (delete) a join request for a group the caller administers.

    Raises:
        HTTPException(404): no such request among the caller's groups.
    """
    with compair.Session() as session:
        caller_is_admin = models.JoinRequest.group.has(
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        )
        join_request = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.request_id == request_id, caller_is_admin)
            .first()
        )
        if not join_request:
            raise HTTPException(status_code=404, detail="Request not found")
        session.delete(join_request)
        session.commit()
    return {"message": "Request rejected successfully"}
2142
+
2143
@router.post("/admin/update_group")
def update_group(
    group_id: str,
    name: Optional[str] = Form(None),
    visibility: Optional[str] = Form(None),
    category: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Update group settings.

    Only fields supplied with a non-empty value change. Omitted or empty form
    fields leave the stored value as it is. A group the current user does not
    administer is reported as 404.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        updates = {
            "name": name,
            "visibility": visibility,
            "category": category,
            "description": description,
        }
        for attr, value in updates.items():
            if value:
                setattr(group, attr, value)
        session.commit()

    return {"message": "Group updated successfully"}
2179
+
2180
@router.post("/admin/invite_member")
def invite_member_to_group(
    request: schema.InviteMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """
    Invite an existing Compair user (by username or email) to join a group.

    If the user exists, create a GroupInvitation, email them an invitation
    link, and mark the invitation ``sent``.

    Raises:
        HTTPException(403): ``request.admin_id`` is not the authenticated user,
            or that user does not administer the group.
        HTTPException(404): the group is not among those ``admin_id``
            administers, or the invitee does not exist.
        HTTPException(400): the invitee is already a member or already invited.
    """
    # Plain assignments, with no trailing commas. A trailing comma makes the
    # value a 1-tuple and breaks every column comparison below.
    admin_id = request.admin_id
    group_id = request.group_id
    username = request.username
    # The admin id in the body must be the authenticated caller. Otherwise any
    # user could act as another group's admin; the other /admin endpoints also
    # derive the admin from current_user.
    if str(admin_id) != str(current_user.user_id):
        raise HTTPException(status_code=403, detail="You are not authorized to manage this group")
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        # Usernames double as email addresses (invitations are mailed to them).
        user = session.query(models.User).filter(models.User.username == username).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        # Check if already a member or already invited
        if user in group.users:
            raise HTTPException(status_code=400, detail="User is already a member")
        existing_invite = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username
        ).first()
        if existing_invite:
            raise HTTPException(status_code=400, detail="User already invited")

        token = secrets.token_urlsafe(32).lower()
        invitation = models.GroupInvitation(
            group_id=group_id,
            inviter_id=admin_id,
            token=token,
            email=user.username,
            datetime_expiration=datetime.utcnow() + timedelta(days=7),
            status='pending'
        )
        session.add(invitation)
        session.commit()

        invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={token}&user_id={user.user_id}"
        # str.replace needs a string; fall back when the admin has no name.
        inviter_name = admin.user.name or admin.user.username or ""
        emailer.connect()
        emailer.send(
            subject="You’re Invited to Join a Group on Compair",
            sender=EMAIL_USER,
            receivers=[user.username],
            html=GROUP_INVITATION_TEMPLATE.replace(
                "{{ inviter_name }}", inviter_name
            ).replace(
                "{{ group_name }}", group.name
            ).replace(
                "{{ invitation_link }}", invitation_link
            )
        )
        invitation.status = 'sent'
        session.commit()
    return {"message": "Invitation sent successfully"}
2255
+
2256
@router.post("/admin/invite_new_user")
def invite_new_user_to_group(
    request: schema.InviteToGroupRequest,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Generate an invitation link to Compair and send it via email, logging a Group request to send on signup.

    A pending ``GroupInvitation`` keyed by ``request.email`` is stored, and the
    admin's referral link is emailed to that address.

    Raises:
        HTTPException(403): ``request.admin_id`` is not the authenticated user,
            or that user does not administer the group.
        HTTPException(404): the group is not among those ``admin_id``
            administers.
    """
    admin_id = request.admin_id
    group_id = request.group_id
    email = request.email
    # Same rule as the other /admin endpoints: the acting admin is the caller,
    # never an id taken on trust from the request body.
    if str(admin_id) != str(current_user.user_id):
        raise HTTPException(status_code=403, detail="You are not authorized to manage this group")
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        referral_link = generate_referral_link(admin.user.referral_code)
        # str.replace needs a string; fall back when the admin has no name.
        inviter_name = admin.user.name or admin.user.username or ""

        emailer.connect()
        # Store a "pending" group invitation for this email so it can be
        # delivered once the invitee signs up.
        if group_id:
            token = secrets.token_urlsafe(32).lower()
            invitation = models.GroupInvitation(
                group_id=group_id,
                inviter_id=admin.user.user_id,
                token=token,
                email=email,
                datetime_expiration=datetime.utcnow() + timedelta(days=7),
                status="pending"
            )
            session.add(invitation)
            session.commit()
        emailer.send(
            subject="You're Invited to Compair!",
            sender=EMAIL_USER,
            receivers=[email],
            html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                "{{inviter_name}}", inviter_name
            ).replace(
                "{{referral_link}}", referral_link,
            )
        )

    return {"message": "Invitation to join Compair successful."}
2311
+
2312
@router.post("/admin/remove_member")
def remove_member(
    request: schema.RemoveMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """Remove ``request.user_id`` from ``request.group_id``.

    The caller must be one of the group's admins. The removal is logged as a
    ``remove_member`` activity attributed to the caller.

    Raises:
        HTTPException(404): the group does not exist, or the user is not a
            member of it.
        HTTPException(403): the caller does not administer the group.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == request.group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        caller_id = current_user.user_id
        if all(adm.user_id != caller_id for adm in group.admins):
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        member = session.query(models.User).filter(models.User.user_id == request.user_id).first()
        if not member or member not in group.users:
            raise HTTPException(status_code=404, detail="User is not a member of this group")

        group.users.remove(member)
        session.commit()

        log_activity(
            session=session,
            user_id=caller_id,
            group_id=group.group_id,
            action="remove_member",
            object_id=member.user_id,
            object_name=member.name,
            object_type="user"
        )

        return {"message": f"User {member.name} has been removed from the group"}
2349
+
2350
@router.get("/accept_group_invitation")
def accept_group_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a group invitation using a token.

    Adds the current user to the invitation's group and marks the invitation
    ``accepted``. Accepting again when already a member does not add a
    duplicate membership.

    Raises:
        HTTPException(404): no invitation has this token, or its group or the
            current user no longer exists.
        HTTPException(400): the invitation has expired; it is marked
            ``expired``.
    """
    with compair.Session() as session:
        invitation = session.query(models.GroupInvitation).filter(models.GroupInvitation.token == token).first()
        if not invitation:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation")

        expiration = invitation.datetime_expiration
        # Invitations are created with naive ``datetime.utcnow()`` values.
        # Normalise to aware UTC before comparing with an aware "now"; mixing
        # the two raises TypeError.
        if expiration is not None and expiration.tzinfo is None:
            expiration = expiration.replace(tzinfo=timezone.utc)
        if expiration is not None and expiration < datetime.now(timezone.utc):
            invitation.status = "expired"
            session.commit()
            raise HTTPException(status_code=400, detail="Invitation has expired")

        group = invitation.group
        if group is None:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation")
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")

        # Avoid a duplicate association row when the user is already a member.
        if user not in group.users:
            group.users.append(user)
        invitation.status = "accepted"
        session.commit()

    return {"message": "You have successfully joined the group"}
2375
+
2376
@router.get("/accept_team_invitation")
def accept_team_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a team invitation using a token.

    ``token`` is matched against ``TeamInvitation.invitation_id``. Only
    invitations with status ``sent`` are accepted. The user is moved to the
    invitation's team, and the invitation is marked ``accepted``.

    Raises:
        HTTPException(404): no sent invitation matches the token.
    """
    require_feature(HAS_TEAM, "Team collaboration")
    with compair.Session() as session:
        invitation = session.query(models.TeamInvitation).filter(
            models.TeamInvitation.invitation_id == token,
            models.TeamInvitation.status == "sent",
        ).first()
        if not invitation:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation.")

        # ``current_user`` comes from the auth dependency and may not belong to
        # this session. If so, changes to it would not be flushed by
        # ``session.commit()``. Load the row through this session instead
        # (the identity map returns the same object if it is already attached).
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is not None and hasattr(user, "team_id"):
            user.team_id = invitation.team_id
        invitation.status = "accepted"
        session.commit()

    return {"message": "You have successfully joined the team!"}
2397
+
2398
def get_feedback_tooltip(
    feedback_id: str,
) -> str:
    """Join the text of every feedback row with ``feedback_id`` using " | ".

    Returns an empty string when no row matches.
    """
    with compair.Session() as session:
        matching = (
            session.query(models.Feedback)
            .filter(models.Feedback.feedback_id == feedback_id)
            .all()
        )
        texts = [row.feedback for row in matching]
    return " | ".join(texts)
2405
+
2406
@router.get("/get_activity_feed")
def get_activity_feed(
    user_id: str,
    page: int = 1,
    page_size: int = 10,
    include_own_activities: bool = True,
    current_user: models.User = Depends(get_current_user)
):
    """Retrieve recent activities for the current user's groups, newest first.

    Args:
        user_id: Accepted for API compatibility but unused. The feed is always
            built from ``current_user``'s groups.
        page: 1-based page number. Values below 1 are treated as 1.
        page_size: Items per page. Values below 1 are treated as 1.
        include_own_activities: When false, the caller's own activities are
            excluded.

    Returns:
        ``{"activities": [...], "total_count": int}``. ``total_count`` counts
        all matching activities, not just this page.

    Raises:
        HTTPException(501): not running the Cloud edition.
    """
    require_feature(HAS_ACTIVITY, "Activity feed")
    if not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Activity feed is only available in the Compair Cloud edition.")
    # Clamp pagination so a bad query string cannot produce a negative
    # OFFSET/LIMIT, which some databases reject outright.
    page = max(page, 1)
    page_size = max(page_size, 1)
    with compair.Session() as session:
        group_ids = [g.group_id for g in current_user.groups]

        q = (
            session.query(models.Activity)
            .filter(models.Activity.group_id.in_(group_ids))
            .order_by(models.Activity.timestamp.desc())
        )
        if not include_own_activities:
            q = q.filter(models.Activity.user_id != current_user.user_id)

        total_count = q.count()
        offset = (page - 1) * page_size
        activities = q.offset(offset).limit(page_size).all()

        # Built inside the session: activity.user is a relationship access.
        return {
            "activities": [
                {
                    "user": activity.user.name,
                    "user_id": activity.user_id,
                    "group_id": activity.group_id,
                    "action": activity.action,
                    "object": f"{activity.object_type} {activity.object_name}",
                    "object_type": activity.object_type,
                    "object_name": activity.object_name,
                    "object_id": activity.object_id,
                    "timestamp": activity.timestamp,
                    "tooltip": get_feedback_tooltip(activity.object_id) if activity.action == "provided feedback" else None
                }
                for activity in activities
            ],
            "total_count": total_count
        }
2454
+
2455
@router.delete("/delete_group")
def delete_group(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete a group. Allowed only if the current user is an admin of the group.

    This removes the group and its associations. Documents remain but lose
    their link to this group. A ``delete`` activity is then logged on a
    best-effort basis.

    Raises:
        HTTPException(404): the group does not exist.
        HTTPException(403): the current user is not one of its admins.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        is_admin = any(a.user_id == current_user.user_id for a in group.admins)
        if not is_admin:
            raise HTTPException(status_code=403, detail="Not authorized to delete this group")
        name = group.name
        session.delete(group)
        session.commit()
        try:
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=group_id,
                object_name=name,
                object_type="group",
            )
        except Exception as exc:
            # Best-effort: the group row is already gone, so an activity that
            # references it may be rejected. Report the failure rather than
            # hiding it, and roll back so the session is not left in a failed
            # transaction.
            session.rollback()
            print(f"log_activity failed for deleted group {group_id}: {exc}")
    return {"message": "Group deleted", "group_id": group_id}
2489
+
2490
# Billing plan name -> Stripe price ID, passed to the billing provider as
# ``price_id`` when creating checkout sessions.
plan_to_id = dict(
    individual_monthly="price_1Ry34UEPPPghXLIJ4OLQNEna",
    individual_annual="price_1Ry34NEPPPghXLIJeftgEYOD",
    team_monthly="price_1Ry34HEPPPghXLIJoEyGRm7h",
    team_annual="price_1Ry34REPPPghXLIJXhLhwpBW",
)
2498
+
2499
@router.post("/create-checkout-session")
def create_checkout_session(
    checkout_map: Mapping,
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Create a checkout session through the configured billing provider.

    Fields read from ``checkout_map``:
        plan: a key of ``plan_to_id`` (default ``"individual_monthly"``).
        team_emails: list, or comma-separated string, of invitee emails;
            team plans need at least ``min_team_emails`` of them and are
            billed one seat per email.
        team_name: optional team display name for team plans.

    A billing customer is created on first use. For team plans a
    ``models.Team`` row plus pending ``models.TeamInvitation`` rows (only
    for emails without an existing account) are created before checkout.
    The checkout session id and plan are stored on the user when those
    attributes exist.

    Returns:
        ``{"url": <checkout url>}``.

    Raises:
        HTTPException: 400 for an unsupported plan or too few team emails,
            404 if the user row is missing, 501 when billing is unavailable,
            500 for provider or unexpected failures.
    """
    require_cloud("Billing")

    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    plan = checkout_map.get("plan", "individual_monthly")
    team_emails = checkout_map.get("team_emails", []) or []
    if isinstance(team_emails, str):
        team_emails = [email.strip() for email in team_emails.split(',') if email.strip()]
    # Reject non-string plans up front: an unhashable value would raise
    # TypeError on the dict lookup, and a non-str breaks .startswith() below.
    if not isinstance(plan, str) or plan not in plan_to_id:
        raise HTTPException(status_code=400, detail="Unsupported plan selected.")

    is_team_plan = plan.startswith("team")
    # The enforced minimum and its error message used to disagree (checked
    # for 3, reported 4). Derive the message from the threshold actually used.
    min_team_emails = 3
    if is_team_plan and len(team_emails) < min_team_emails:
        raise HTTPException(
            status_code=400,
            detail=f"Team plans require at least {min_team_emails} emails.",
        )
    if is_team_plan:
        # Check feature availability before any side effects (billing
        # customer creation, Team rows) so a refusal leaves nothing behind.
        require_feature(HAS_TEAM, "Team collaboration")

    price_id = plan_to_id[plan]
    quantity = 1 if plan.startswith("individual") else len(team_emails)

    try:
        with compair.Session() as session:
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if not db_user:
                raise HTTPException(status_code=404, detail="User not found.")

            if not db_user.stripe_customer_id:
                log_event("stripe_customer_create_attempt", user_id=db_user.user_id)
                try:
                    customer_id = billing.ensure_customer(
                        user_email=db_user.username,
                        user_id=db_user.user_id,
                    )
                except NotImplementedError as exc:
                    raise HTTPException(status_code=501, detail=str(exc)) from exc
                db_user.stripe_customer_id = customer_id
                session.commit()
                log_event("stripe_customer_created", user_id=db_user.user_id, customer_id=customer_id)

            customer_id = db_user.stripe_customer_id

            if is_team_plan:
                team_name = checkout_map.get("team_name", f"{db_user.name}'s Team")
                team = models.Team(
                    name=team_name,
                    total_documents_limit=100 * len(team_emails),
                    daily_feedback_limit=50 * len(team_emails),
                )
                session.add(team)
                session.commit()

                for email in team_emails:
                    existing_user = session.query(models.User).filter(models.User.username == email).first()
                    if not existing_user:
                        invitation = models.TeamInvitation(
                            email=email,
                            inviter_id=db_user.user_id,
                            team_id=team.team_id,
                            status="pending",
                            datetime_created=datetime.now(timezone.utc),
                        )
                        session.add(invitation)
                session.commit()

            log_event("stripe_checkout_session_start", user_id=db_user.user_id)

            try:
                session_info = billing.create_checkout_session(
                    customer_id=customer_id,
                    price_id=price_id,
                    qty=quantity,
                    success_url=settings.stripe_success_url,
                    cancel_url=settings.stripe_cancel_url,
                    metadata={"plan": plan, "user_id": db_user.user_id},
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc
            except Exception as exc:
                raise HTTPException(status_code=500, detail=str(exc)) from exc

            # Providers may return an attribute-style object or a plain mapping.
            session_id = getattr(session_info, "id", None)
            session_url = getattr(session_info, "url", None)
            if session_id is None and hasattr(session_info, "get"):
                session_id = session_info.get("id")
            if session_url is None and hasattr(session_info, "get"):
                session_url = session_info.get("url")
            if not session_id or not session_url:
                raise HTTPException(status_code=500, detail="Billing provider did not return a checkout URL.")

            if hasattr(db_user, "checkout_session"):
                db_user.checkout_session = session_id
            if hasattr(db_user, "plan"):
                db_user.plan = plan
            session.commit()
            log_event("stripe_checkout_session_created", user_id=db_user.user_id, session_id=session_id)
            return {"url": session_url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
2608
+
2609
+
2610
@router.get("/redirect-to-checkout")
def redirect_to_checkout(
    user_id: str,
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Resolve and consume the checkout URL stored for ``user_id``.

    The stored checkout session id is one-shot: it is cleared (set to
    ``None``) once the provider has resolved its URL.

    Returns:
        ``{"url": <checkout url>}``.

    Raises:
        HTTPException: 404 when billing is disabled or no checkout session is
            stored for the user; 501 if the provider cannot resolve URLs.
    """
    # NOTE(review): this endpoint is unauthenticated and keyed only by
    # user_id, so anyone who knows a user id can fetch (and consume) that
    # user's pending checkout URL. Consider requiring get_current_user.
    require_cloud("Billing")
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if not user or not user.checkout_session:
            # Nothing to redirect to is a lookup miss, not a server fault.
            raise HTTPException(status_code=404, detail=f"No checkout session found for user {user_id}")

        session_id = user.checkout_session
        try:
            checkout_url = billing.get_checkout_url(session_id)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc

        user.checkout_session = None
        session.commit()

    return {"url": checkout_url}
2635
+
2636
+
2637
+
2638
def track_payment_method(fingerprint: str):
    """Enforce the referral-credit cap for one payment method.

    Counts the users who share ``fingerprint`` and currently hold a positive
    ``referral_credits`` balance.

    Raises:
        HTTPException: 400 when three or more such users already exist.
    """
    max_credits_per_payment = 3
    with compair.Session() as session:
        credited_users = (
            session.query(models.User)
            .filter(
                models.User.payment_fingerprint == fingerprint,
                models.User.referral_credits > 0,
            )
            .count()
        )
        if credited_users < max_credits_per_payment:
            return
        raise HTTPException(
            status_code=400,
            detail="Referral credit limit reached for this payment method.",
        )
2654
+
2655
def handle_successful_checkout(session):
    """Process a completed Stripe Checkout Session payload.

    1. If the payload carries a ``payment_method``, store it as the matching
       user's ``payment_fingerprint`` and enforce the referral cap via
       ``track_payment_method``.
    2. If ``payment_status`` is ``"paid"`` and the user is on a team plan
       that owns a team, deliver that user's pending team invitations:
       existing accounts get a join link and are marked ``"sent"``; other
       emails receive a Compair referral invite via ``send_invite``.

    Args:
        session: the Stripe Checkout Session object (a mapping). The name is
            kept for backward compatibility; it is NOT a database session.
    """
    # NOTE(review): stripe_webhook does not dispatch any event type to this
    # handler; confirm whether checkout.session.completed should reach it.
    checkout = session
    customer_id = checkout.get("customer")
    if customer_id is None:
        # NOTE(review): this email is later compared with stripe_customer_id,
        # which only matches if that column can hold emails — confirm.
        customer_id = (checkout.get("customer_details") or {}).get("email")

    payment_method = checkout.get("payment_method")  # Might not exist if using trial
    if payment_method:
        # Use a distinct name: rebinding ``session`` here used to shadow the
        # Stripe payload, so the later payment_status lookup indexed the DB
        # session object and raised TypeError.
        with compair.Session() as db_session:
            user = db_session.query(models.User).filter(
                models.User.stripe_customer_id == customer_id
            ).first()
            if user:
                user.payment_fingerprint = payment_method
                track_payment_method(user.payment_fingerprint)
                db_session.commit()

    if checkout.get("payment_status") != "paid":
        return

    with compair.Session() as db_session:
        user = db_session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        # ``plan`` may be unset; treat that as "not a team plan".
        if not user or not (user.plan or "").startswith("team"):
            return

        team = db_session.query(models.Team).filter(
            models.Team.ownership.has(owner_id=user.user_id)
        ).first()
        if not team:
            return

        invitations = db_session.query(models.TeamInvitation).filter(
            models.TeamInvitation.inviter_id == user.user_id,
            models.TeamInvitation.status == "pending",
        ).all()
        for invitation in invitations:
            existing_user = db_session.query(models.User).filter(
                models.User.username == invitation.email
            ).first()

            if existing_user:
                # Send team invitation email to existing user
                emailer.connect()
                emailer.send(
                    subject="You're Invited to Join a Team on Compair",
                    sender=EMAIL_USER,
                    receivers=[invitation.email],
                    html=f"""
                    <p>{user.name} has invited you to join their team on Compair!</p>
                    <p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>
                    """
                )
                invitation.status = "sent"
            else:
                # Pass group_id explicitly: called as a plain function, the
                # endpoint's default is its Form(None) marker, which is truthy
                # and would create a GroupInvitation with a bogus group id.
                send_invite(
                    invitee_emails=invitation.email,
                    group_id=None,
                    current_user=user,
                )
        # Persist the "sent" status updates, which were previously never committed.
        db_session.commit()
2726
+
2727
def handle_successful_payment_intent(intent):
    """Record a succeeded PaymentIntent's payment method on its customer.

    When the intent carries a ``payment_method`` and a user with the same
    ``stripe_customer_id`` exists (and has a ``payment_fingerprint``
    attribute), that method is stored as the fingerprint and the referral
    cap is enforced through ``track_payment_method``.
    """
    ### Get / Store payment intent to customer_id (if both available!) if payment_method not available
    stripe_customer = intent["customer"]
    method = intent.get("payment_method")
    print("Got to intent")
    print(intent)
    print(method)
    if not method:
        return

    with compair.Session() as db:
        account = (
            db.query(models.User)
            .filter(models.User.stripe_customer_id == stripe_customer)
            .first()
        )
        if not account or not hasattr(account, "payment_fingerprint"):
            return
        account.payment_fingerprint = method
        track_payment_method(account.payment_fingerprint)
        db.commit()
2746
+
2747
def handle_successful_invoice_payment(
    invoice,
    billing: BillingProvider,
    analytics: Analytics | None = None,
):
    """Activate the paying customer's account after a paid invoice.

    Steps:
      1. Resolve the plan key by matching the first invoice line that has a
         price id against the values of ``plan_to_id``.
      2. If the invoice has a ``payment_method``, store it as the user's
         ``payment_fingerprint`` and enforce the referral cap.
      3. Mark the user ``'active'``, stamp payment/status dates, store the
         plan, and convert 10 of the referrer's pending referral credits
         (email notice plus a $15 coupon, $25 for team-plan referrers).
      4. Emit a ``subscription_created`` analytics event when available.

    Args:
        invoice: Stripe invoice payload (mapping).
        billing: provider used to create and apply referral coupons.
        analytics: optional analytics sink; failures are printed, not raised.
    """
    customer_id = invoice["customer"]

    price_id = None
    for line in invoice["lines"]["data"]:
        # ``price`` can be missing or null on a line item; the old
        # ``"id" in line["price"]`` raised TypeError on None.
        price = line.get("price")
        if price and "id" in price:
            price_id = price["id"]
            break

    plan = next((key for key, value in plan_to_id.items() if value == price_id), None)

    payment_method = invoice.get("payment_method")
    if payment_method:
        with compair.Session() as session:
            user = session.query(models.User).filter(
                models.User.stripe_customer_id == customer_id
            ).first()
            if user:
                user.payment_fingerprint = payment_method
                track_payment_method(user.payment_fingerprint)
                session.commit()

    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        if user:
            now = datetime.now(timezone.utc)
            user.status = 'active'  # Mark the user as subscribed
            if hasattr(user, "last_payment_date"):
                user.last_payment_date = now
            user.status_change_date = now
            if plan and hasattr(user, "plan"):
                user.plan = plan

            # Handle referrer logic
            if hasattr(user, "referred_by") and user.referred_by:
                referrer = session.query(models.User).filter(
                    models.User.user_id == user.referred_by
                ).first()
                if referrer and hasattr(referrer, "pending_referral_credits") and referrer.pending_referral_credits > 0:
                    # Convert pending credits to earned credits
                    if hasattr(referrer, "referral_credits"):
                        referrer.referral_credits += 10
                        referrer.pending_referral_credits -= 10
                    emailer.connect()
                    emailer.send(
                        subject="You've Earned a Free Month!",
                        sender=EMAIL_USER,
                        receivers=[referrer.username],
                        html=REFERRAL_CREDIT_TEMPLATE.replace(
                            "{{user_name}}", referrer.name
                        ).replace(
                            "{{referral_credits}}", str(referrer.referral_credits),
                        )
                    )
                    # Team plans are stored as "team_monthly"/"team_annual", so
                    # the previous ``== 'team'`` test never matched and team
                    # referrers always received the individual amount.
                    referrer_plan = getattr(referrer, "plan", None) or ""
                    amount = 25 if referrer_plan.startswith("team") else 15
                    try:
                        coupon_id = billing.create_coupon(amount)
                        if referrer.stripe_customer_id:
                            billing.apply_coupon(customer_id=referrer.stripe_customer_id, coupon_id=coupon_id)
                    except NotImplementedError:
                        pass  # Provider has no coupon support; the credit itself still stands.

            if analytics:
                try:
                    analytics.track("subscription_created", user.user_id)
                except Exception as exc:
                    print(f"analytics track failed: {exc}")

        session.commit()
2840
+
2841
def handle_subscription_cancellation(subscription):
    """Suspend the account whose Stripe subscription was deleted.

    The user is looked up by the subscription's ``customer`` id; when found,
    the status becomes ``'suspended'`` and ``status_change_date`` is stamped
    with the current UTC time. Unknown customers are ignored.
    """
    stripe_customer = subscription["customer"]
    print('Cancelled')
    print(subscription)
    with compair.Session() as db:
        account = (
            db.query(models.User)
            .filter(models.User.stripe_customer_id == stripe_customer)
            .first()
        )
        if not account:
            return
        account.status = 'suspended'  # Remove access
        account.status_change_date = datetime.now(timezone.utc)
        db.commit()
2857
+
2858
def handle_failed_payment(invoice):
    """Suspend the account whose invoice payment failed.

    The user is looked up by the invoice's ``customer`` id; when found, the
    status becomes ``'suspended'`` until payment is made and
    ``status_change_date`` is stamped with the current UTC time. Unknown
    customers are ignored.
    """
    failed_customer = invoice["customer"]
    print('Failure')
    print(invoice)
    with compair.Session() as db:
        account = (
            db.query(models.User)
            .filter(models.User.stripe_customer_id == failed_customer)
            .first()
        )
        if not account:
            return
        account.status = 'suspended'
        account.status_change_date = datetime.now(timezone.utc)
        db.commit()
2874
+
2875
@router.post("/webhook")
async def stripe_webhook(
    request: Request,
    billing: BillingProvider = Depends(get_billing),
    analytics: Analytics = Depends(get_analytics),
    settings: Settings = Depends(get_settings_dependency),
):
    """Verify an incoming billing webhook and dispatch it to its handler.

    Handled event types are ``payment_intent.succeeded``,
    ``invoice.payment_succeeded``, ``customer.subscription.deleted`` and
    ``invoice.payment_failed``. Any other verified event is acknowledged
    without action.

    Returns:
        ``{"status": "success"}``.

    Raises:
        HTTPException: 404 when billing is disabled; 400 when the payload or
            signature cannot be verified; HTTPExceptions raised by a handler
            propagate unchanged; any other handler failure becomes a 500.
    """
    # NOTE(review): handle_successful_checkout is not reachable from here
    # (no checkout.session.completed entry) — confirm that is intended.
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    try:
        event = billing.construct_event(payload, sig_header)
        event_type = event["type"]
        event_object = event["data"]["object"]
    except Exception as exc:
        # Bad signature or malformed payload: the request itself is invalid.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    handlers = {
        "payment_intent.succeeded": handle_successful_payment_intent,
        "invoice.payment_succeeded": lambda obj: handle_successful_invoice_payment(obj, billing, analytics),
        "customer.subscription.deleted": handle_subscription_cancellation,
        "invoice.payment_failed": handle_failed_payment,
    }

    try:
        log_event("stripe_webhook_start", event_type=event_type, data=event_object)
        handler = handlers.get(event_type)
        if handler is not None:
            handler(event_object)
        log_event("stripe_webhook_complete", event_type=event_type, data=event_object)
    except HTTPException:
        # Keep the handler's own status/detail instead of rewrapping as 400.
        raise
    except Exception as exc:
        # A processing failure is a server-side error, not a bad request.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"status": "success"}
2915
+
2916
+
2917
+
2918
def generate_referral_link(referral_code: str) -> str:
    """Return the login URL that carries ``referral_code`` as ``?ref=``.

    Only available in the cloud edition (``require_cloud`` is checked first).
    """
    require_cloud("Referral program")
    return "http://{}/login?ref={}".format(WEB_URL, referral_code)
2925
+
2926
@router.post("/send-invite")
def send_invite(
    invitee_emails: str = Form(...),
    group_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
):
    """Email referral invitations, optionally tracking a pending group invite.

    Args:
        invitee_emails: comma-separated recipient addresses; whitespace
            around each address and empty entries are ignored.
        group_id: when set, a pending ``models.GroupInvitation`` with a
            7-day expiry is stored per recipient so sign-up can auto-invite
            them to the group.
        current_user: the inviter; supplies the referral code and name.

    Returns:
        ``{"message": "Invitation sent successfully"}``.
    """
    # NOTE(review): group_id is not checked against the inviter's own groups.
    require_cloud("Referral program")
    # When this endpoint is called as a plain function, an omitted group_id
    # is the Form(None) parameter marker rather than None (and is truthy);
    # treat anything that is not a real string as "no group".
    if not isinstance(group_id, str):
        group_id = None
    # Previously "a@x.com, b@x.com" produced " b@x.com" and a trailing comma
    # produced an empty recipient.
    recipients = [email.strip() for email in invitee_emails.split(',') if email.strip()]

    with compair.Session() as session:
        referral_link = generate_referral_link(current_user.referral_code)

        emailer.connect()
        for invitee_email in recipients:
            if group_id:
                # Store a "pending" group invitation for this email
                token = secrets.token_urlsafe(32).lower()
                invitation = models.GroupInvitation(
                    group_id=group_id,
                    inviter_id=current_user.user_id,
                    token=token,
                    email=invitee_email,
                    # NOTE(review): naive UTC, unlike the tz-aware timestamps used
                    # elsewhere in this module; left as-is because the code that
                    # compares this expiry is not visible here.
                    datetime_expiration=datetime.utcnow() + timedelta(days=7),
                    status="pending"
                )
                session.add(invitation)
                session.commit()
            emailer.send(
                subject="You're Invited to Compair!",
                sender=EMAIL_USER,
                receivers=[invitee_email],
                html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                    "{{inviter_name}}", current_user.name
                ).replace(
                    "{{referral_link}}", referral_link,
                )
            )

    return {"message": "Invitation sent successfully"}
2972
+
2973
@router.post("/get-customer-portal")
def get_customer_portal(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Generate a billing portal session link through the billing provider.

    The portal sends the user back to ``/home`` on ``WEB_URL`` (or the
    public compair.sh site when ``WEB_URL`` is unset).

    Returns:
        ``{"url": <portal url>}``.

    Raises:
        HTTPException: 501 when billing is unavailable, 400 when the user has
            no billing customer id, 500 for unexpected provider errors.
    """
    # Gate on the cloud edition like the other billing endpoints in this module.
    require_cloud("Billing")

    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    return_url = f"http://{WEB_URL}/home" if WEB_URL else "https://compair.sh/home"

    try:
        with compair.Session() as session:
            user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if not user or not user.stripe_customer_id:
                raise HTTPException(status_code=400, detail="No billing profile found for this user.")

            try:
                portal_url = billing.create_customer_portal(
                    customer_id=user.stripe_customer_id,
                    return_url=return_url,
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc

            return {"url": portal_url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
3005
+
3006
+
3007
+
3008
+
3009
@router.post("/upload/profile")
async def upload_profile_image(
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """Upload the caller's profile picture to Cloudflare Images.

    The returned Cloudflare image id is stored on the user as
    ``profile_image`` and a ``cloudflare_upload`` event is logged.

    Returns:
        ``{"image_id": <Cloudflare image id>}``.

    Raises:
        HTTPException: 400 for a wrong ``upload_type`` or unsupported image
            type, 404 if the user row is missing, 500 if Cloudflare rejects
            the upload.
    """
    if upload_type != 'profile':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # Read the upload once and reuse the bytes. The old second read() ran
    # after the stream was exhausted, so the logged file_size was always 0.
    file_bytes = await file.read()
    file_size = len(file_bytes)

    # Upload to Cloudflare Images
    headers = {"Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"}
    data = {"requireSignedURLs": "false"}
    files = {"file": (file.filename, file_bytes, file.content_type)}

    async with httpx.AsyncClient() as client:
        response = await client.post(
            CLOUDFLARE_IMAGES_UPLOAD_URL,
            headers=headers,
            data=data,
            files=files,
            timeout=30
        )
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    result = response.json()
    if not result.get("success"):
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    image_id = result["result"]["id"]

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found.")
        user.profile_image = image_id
        session.commit()

    log_event(
        "cloudflare_upload",
        upload_type="profile_image",
        user_id=current_user.user_id,
        file_size=file_size,
        file_key=image_id,
        content_type=file.content_type
    )

    return {"image_id": image_id}
3070
+
3071
+
3072
@router.get("/get_profile_image")
def get_profile_image(
    variant: str = Query("public", enum=["public", "avatar", "preview"]),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """Return the Cloudflare Images delivery URL for the caller's avatar.

    Args:
        variant: one of ``public``, ``avatar`` or ``preview``; appended as
            the final URL path segment.

    Returns:
        ``{"url": "<CLOUDFLARE_IMAGES_BASE_URL>/<image_id>/<variant>"}``.

    Raises:
        HTTPException: 400 when the user has no profile image.
    """
    # Only the already-loaded current_user is read, so the DB session the
    # original opened here (and never used) is unnecessary.
    image_id = current_user.profile_image
    if not image_id:
        raise HTTPException(status_code=400, detail="No profile image found")
    return {"url": f"{CLOUDFLARE_IMAGES_BASE_URL}/{image_id}/{variant}"}
3083
+
3084
+
3085
@router.post("/upload/file")
async def upload_file(
    document_id: str = Form(...),
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
) -> Mapping[str, str]:
    """Store a PDF for one of the caller's Documents and record its key.

    The storage key is ``file/<user_id>/<sha256(filename)>/<document_id>``.
    The key is saved on the document as ``file_key``, and a
    ``cloudflare_upload`` event is logged.

    Returns:
        ``{"url": <value returned by storage.put_file>}``.

    Raises:
        HTTPException: 400 for a wrong ``upload_type`` or non-PDF file, 404 if
            the document does not exist, 403 if the caller is not its author,
            501/500 if the storage provider is unsupported or fails.
    """
    if upload_type != 'file':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["application/pdf"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    file_name_hash = hashlib.sha256(f"{file.filename}".encode()).hexdigest()
    file_key = f"{upload_type}/{current_user.user_id}/{file_name_hash}/{document_id}"

    # Read once to measure, then rewind so storage receives the full stream.
    file_bytes = await file.read()
    file_size = len(file_bytes)
    await file.seek(0)

    content_type = file.content_type or "application/octet-stream"

    with compair.Session() as session:
        # Validate the target before writing to storage. Previously the
        # upload happened first, leaving orphaned objects for unknown ids,
        # and any user could attach a file to any document.
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        if document.author_id != current_user.user_id:
            raise HTTPException(status_code=403, detail="Not authorized to upload a file for this document.")

        try:
            upload_url = storage.put_file(file_key, file.file, content_type)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc
        except Exception as exc:  # pragma: no cover - surfaced to client
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        document.file_key = file_key
        session.commit()

    log_event(
        "cloudflare_upload",
        upload_type="document_file",
        user_id=current_user.user_id,
        document_id=document_id,
        file_size=file_size,
        file_key=file_key,
        content_type=content_type
    )

    return {"url": upload_url}
3139
+
3140
+
3141
@router.post("/upload/group")
async def upload_group_image(
    group_id: str,
    upload_type: str,
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
) -> Mapping[str, str]:
    """Upload a group's image to Cloudflare Images and store its id.

    Only an admin of the group may change its image. The Cloudflare image id
    is saved as ``group_image`` and a ``cloudflare_upload`` event is logged.

    Returns:
        ``{"image_id": <Cloudflare image id>}``.

    Raises:
        HTTPException: 400 for a wrong ``upload_type`` or unsupported image
            type, 404 if the group does not exist, 403 if the caller is not a
            group admin, 500 if Cloudflare rejects the upload.
    """
    if upload_type != 'group':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        # Previously unauthenticated: anyone could replace any group's image.
        # Uses the same admin test as the group-delete endpoint.
        if not any(a.user_id == current_user.user_id for a in group.admins):
            raise HTTPException(status_code=403, detail="Not authorized to update this group")

        # Read once and reuse; a second read() after EOF returned b"" and
        # made the logged file_size always 0.
        file_bytes = await file.read()
        file_size = len(file_bytes)

        # Upload to Cloudflare Images
        headers = {"Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"}
        data = {"requireSignedURLs": "false"}
        files = {"file": (file.filename, file_bytes, file.content_type)}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                CLOUDFLARE_IMAGES_UPLOAD_URL,
                headers=headers,
                data=data,
                files=files,
                timeout=30
            )
        if response.status_code != 200:
            raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

        result = response.json()
        if not result.get("success"):
            raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

        image_id = result["result"]["id"]
        group.group_image = image_id
        session.commit()

    log_event(
        "cloudflare_upload",
        upload_type="group_image",
        group_id=group_id,
        file_size=file_size,
        file_key=image_id,
        content_type=file.content_type
    )

    # The declared return type is Mapping[str, str]. FastAPI uses the return
    # annotation as the response model, so returning None failed response
    # validation after the image had already been saved.
    return {"image_id": image_id}
3200
+
3201
@router.post("/send-feature-announcement")
def send_feature_announcement(admin_key: str = Header(None)):
    """Manually trigger a feature announcement email to churned users.

    Requires an admin API key in the ``admin_key`` header (exposed by
    FastAPI as ``admin-key``) that matches ``ADMIN_API_KEY``. The emails are
    sent by enqueuing ``send_feature_announcement_task`` on Celery.

    Raises:
        HTTPException: 501 outside the cloud edition or when the task is
            unavailable; 403 when the key is missing, wrong, or no admin key
            is configured.
    """
    if not IS_CLOUD or send_feature_announcement_task is None:
        raise HTTPException(status_code=501, detail="Feature announcements require the Compair Cloud edition.")
    # Fail closed: with ADMIN_API_KEY unset, a request without the header
    # previously passed because None == None. Compare in constant time.
    if (
        not ADMIN_API_KEY
        or not admin_key
        or not secrets.compare_digest(admin_key.encode(), str(ADMIN_API_KEY).encode())
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")

    send_feature_announcement_task.delay()  # Run email sending as an async Celery task

    return {"message": "Feature announcement email campaign triggered successfully."}
3215
+
3216
+
3217
@router.post("/retrial/{user_id}")
def grant_retrial(user_id: str):
    """Reactivate a churned (suspended) user's trial for 14 days.

    Each user may receive at most one re-trial; the counter and the re-trial
    date are recorded on the user.

    Raises:
        HTTPException: 501 outside the cloud edition; 400 if the user does
            not exist or is not suspended; 403 once the re-trial limit is hit.
    """
    # NOTE(review): unauthenticated — any caller can trigger this for any
    # suspended user_id (bounded by the one-re-trial limit). Confirm intended.
    if not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Re-trials require the Compair Cloud edition.")
    # Use the same session factory as every other endpoint in this module.
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()

        if not user or user.status != "suspended":
            raise HTTPException(status_code=400, detail="User not eligible for a re-trial.")

        # Treat a NULL counter as zero instead of failing with TypeError.
        retrials_used = user.retrial_count or 0
        if retrials_used >= 1:
            raise HTTPException(status_code=403, detail="Re-trial limit reached.")

        now = datetime.now(timezone.utc)
        user.status = "trial"
        user.status_change_date = now
        user.trial_expiration_date = now + timedelta(days=14)
        user.retrial_count = retrials_used + 1
        user.last_retrial_date = now
        session.commit()

    return {"message": "Your re-trial has been activated!"}
3242
+
3243
@router.post("/documents/{document_id}/notes")
def create_note(
    document_id: str,
    content: str = Form(...),
    group_id: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Create a Note on a Document and index its content as note chunks.

    The caller must share at least one group with the document. An explicit
    ``group_id`` must be one of those shared groups; when omitted, the first
    shared group (in the document's group order) is used.

    Each chunk of ``content`` is embedded and stored as a ``models.Chunk``
    with ``chunk_type="note"``, and a "create" activity is logged.

    Returns:
        ``schema.Note`` for the new note, with the caller as ``author``.

    Raises:
        HTTPException: 404 if the document does not exist; 403 if the caller
            shares no group with it or ``group_id`` is not a shared group.
    """
    with compair.Session() as session:
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")

        user_group_ids = {g.group_id for g in current_user.groups}
        shared_group_ids = [g.group_id for g in document.groups if g.group_id in user_group_ids]
        if not shared_group_ids:
            raise HTTPException(status_code=403, detail="User not in document's group")
        # Previously any group_id was accepted, letting a note be filed under
        # a group the caller (or the document) does not belong to.
        if group_id and group_id not in shared_group_ids:
            raise HTTPException(status_code=403, detail="Group is not shared by this user and document")

        note = models.Note(
            document_id=document_id,
            author_id=current_user.user_id,
            group_id=group_id or shared_group_ids[0],
            content=content,
            datetime_created=datetime.now(timezone.utc)
        )
        session.add(note)
        session.commit()
        session.refresh(note)

        # Chunk and embed Note content
        note_text_chunks = chunk_text(content)
        embedder = Embedder()
        for text in note_text_chunks:
            # NOTE(review): built-in hash() of a str is salted per process
            # (PYTHONHASHSEED), so this value differs across workers and
            # restarts. Left unchanged pending a check of how other chunk
            # hashes are produced/compared elsewhere in the module.
            chunk_hash = hash(text)
            embedding = create_embedding(embedder, text, user=current_user)
            note_chunk = models.Chunk(
                hash=str(chunk_hash),
                document_id=document_id,
                note_id=note.note_id,
                chunk_type="note",
                content=text,
            )
            note_chunk.embedding = embedding
            session.add(note_chunk)
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=note.group_id,
            action="create",
            object_id=note.note_id,
            object_name=f"Note on {document.title}",
            object_type="note"
        )

        return schema.Note(
            note_id=note.note_id,
            document_id=note.document_id,
            author_id=note.author_id,
            group_id=note.group_id,
            content=note.content,
            datetime_created=note.datetime_created,
            author=schema.User(
                user_id=current_user.user_id,
                username=current_user.username,
                name=current_user.name,
                datetime_registered=current_user.datetime_registered,
                status=current_user.status,
                profile_image=current_user.profile_image,
                role=current_user.role,
            )
        )
3314
+
3315
@router.get("/documents/{document_id}/notes")
def list_notes(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> list[schema.Note]:
    """List all Notes for a Document, each with its author's public profile.

    Authors are fetched with one ``IN`` query rather than one query per
    note. A note whose author row no longer exists gets ``author=None``.
    """
    # NOTE(review): no access check — any authenticated user can list the
    # notes of any document id (create_note, by contrast, checks groups).
    with compair.Session() as session:
        notes = session.query(models.Note).filter(models.Note.document_id == document_id).all()

        author_ids = list({note.author_id for note in notes})
        authors_by_id = {}
        if author_ids:
            authors_by_id = {
                author.user_id: author
                for author in session.query(models.User).filter(models.User.user_id.in_(author_ids)).all()
            }

        result = []
        for note in notes:
            author = authors_by_id.get(note.author_id)
            result.append(schema.Note(
                note_id=note.note_id,
                document_id=note.document_id,
                author_id=note.author_id,
                group_id=note.group_id,
                content=note.content,
                datetime_created=note.datetime_created,
                author=schema.User(
                    user_id=author.user_id,
                    username=author.username,
                    name=author.name,
                    datetime_registered=author.datetime_registered,
                    status=author.status,
                    profile_image=author.profile_image,
                    role=author.role,
                ) if author else None
            ))
        return result
3344
+
3345
@router.get("/notes/{note_id}")
def get_note(
    note_id: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Note:
    """Fetch a single Note by id, embedding its author's public profile.

    ``author`` is ``None`` when the author's user row cannot be found.

    Raises:
        HTTPException: 404 if no note has ``note_id``.
    """
    with compair.Session() as session:
        found = session.query(models.Note).filter(models.Note.note_id == note_id).first()
        if not found:
            raise HTTPException(status_code=404, detail="Note not found")

        writer = session.query(models.User).filter(models.User.user_id == found.author_id).first()
        author_profile = None
        if writer:
            author_profile = schema.User(
                user_id=writer.user_id,
                username=writer.username,
                name=writer.name,
                datetime_registered=writer.datetime_registered,
                status=writer.status,
                profile_image=writer.profile_image,
                role=writer.role,
            )

        return schema.Note(
            note_id=found.note_id,
            document_id=found.document_id,
            author_id=found.author_id,
            group_id=found.group_id,
            content=found.content,
            datetime_created=found.datetime_created,
            author=author_profile,
        )
3373
+
3374
@router.get("/documents/{document_id}/file-url")
def get_document_file_url(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """Resolve the storage URL of a Document's uploaded file.

    Returns:
        ``{"file_url": storage.build_url(<document.file_key>)}``.

    Raises:
        HTTPException: 404 if the document is missing or has no ``file_key``.
    """
    # NOTE(review): no permission check, unlike generate_download_token.
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        key = doc.file_key if doc else None
        if not key:
            raise HTTPException(status_code=404, detail="File not found for this document.")
        return {"file_url": storage.build_url(key)}
3387
+
3388
@router.get("/documents/{document_id}/image-url")
def get_document_image_url(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """Resolve the storage URL of a Document's preview image.

    Returns:
        ``{"image_url": storage.build_url(<document.image_key>)}``.

    Raises:
        HTTPException: 404 if the document is missing or has no ``image_key``.
    """
    # NOTE(review): no permission check, unlike generate_download_token.
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        key = doc.image_key if doc else None
        if not key:
            raise HTTPException(status_code=404, detail="Image not found for this document.")
        return {"image_url": storage.build_url(key)}
3401
+
3402
+
3403
def sanitize_filename(filename: str) -> str:
    """Make *filename* safe to embed in a ``Content-Disposition`` header.

    Each run of characters outside printable ASCII is collapsed into a single
    underscore. This covers non-ASCII text, CR/LF and other control
    characters. Double quotes and backslashes are also replaced with an
    underscore, because they are special inside a quoted ``filename="..."``
    parameter. A non-empty input always yields a non-empty result.
    """
    # [^\x20-\x7E]  -> anything that is not printable ASCII (incl. CR/LF, DEL)
    # ["\\]         -> quote / backslash, which would break a quoted-string
    return re.sub(r'[^\x20-\x7E]+|["\\]', '_', filename)
3406
+
3407
+
3408
@router.post("/documents/{document_id}/generate-download-token")
def generate_download_token(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Mint a short-lived, single-use download link for a Document's file.

    The caller must be the document's author. Otherwise the document must be
    published and share at least one group with the caller. The link embeds
    a random token that is stored in Redis for 300 seconds and maps to
    ``document_id``.

    Raises:
        HTTPException: 404 if the document or its file is missing, 403 if
            the caller is not authorized, 501 if Redis is unavailable.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        # Same access rule as the direct download endpoint, written as guards.
        if current_user.user_id != doc.author_id:
            if not doc.is_published:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")
            shared_group_ids = (
                {g.group_id for g in doc.groups}
                & {g.group_id for g in current_user.groups}
            )
            if not shared_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")

        if not HAS_REDIS:
            raise HTTPException(status_code=501, detail="Secure download links require Redis, which is unavailable in the core edition.")

        download_token = secrets.token_urlsafe(32)
        # Expires after 300 seconds; download_document_with_token redeems it.
        redis_client.setex(f"download_token:{download_token}", 300, document_id)
    return {"download_url": f"/documents/download/{download_token}"}
3437
+
3438
+
3439
@router.get("/documents/download/{token}")
def download_document_with_token(
    token: str,
    storage: StorageProvider = Depends(get_storage),
):
    """Stream a Document's file in exchange for a single-use download token.

    Tokens are minted by ``generate_download_token``, which performs the
    authorization check. They are stored in Redis as
    ``download_token:<token>`` mapped to the document ID. No login is
    required here: the token is the credential, and it is deleted the first
    time it is redeemed.

    Raises:
        HTTPException: 501 if Redis is unavailable, 403 if the token is
            unknown, expired or already used, 404 if the document or its
            stored file cannot be found.
    """
    if not HAS_REDIS:
        raise HTTPException(status_code=501, detail="Secure download links require Redis, which is unavailable in the core edition.")

    key = f"download_token:{token}"
    value = None
    if redis_client:
        # Run GET and DEL in one MULTI/EXEC transaction. With two separate
        # calls, concurrent requests could both read the token before either
        # deleted it, redeeming a "single-use" link twice.
        with redis_client.pipeline() as pipe:
            pipe.get(key)
            pipe.delete(key)
            value, _ = pipe.execute()
    # Handle both bytes and str replies (str when the client was created
    # with decode_responses=True).
    document_id = value.decode("utf-8") if isinstance(value, bytes) else value
    if not document_id:
        raise HTTPException(status_code=403, detail="Invalid or expired token")

    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        safe_title = sanitize_filename(doc.title or "file") or "file"
        try:
            file_obj, content_type = storage.get_file(doc.file_key)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc

        # Quote the filename parameter (RFC 6266) so spaces or ';' in the
        # title do not truncate it, and escape quoted-string specials.
        quoted_title = safe_title.replace("\\", "\\\\").replace('"', '\\"')
        return StreamingResponse(
            file_obj,
            media_type=content_type,
            headers={
                "Content-Disposition": f'attachment; filename="{quoted_title}"'
            },
        )
3471
+
3472
+
3473
@router.get("/documents/{document_id}/download")
def download_document_file(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """Stream a Document's attached file to an authorized user.

    The author may always download the file. Other users may download it
    only when the document is published and shares at least one group with
    them. This is the same rule ``generate_download_token`` enforces.

    Raises:
        HTTPException: 404 if the document, its file key or the stored file
            is missing, 403 if the current user may not download it.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        # Check existence before touching doc.title. Reading it first raised
        # AttributeError (HTTP 500) for unknown IDs instead of a 404.
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        # Authorization. Unpublished documents owned by someone else must be
        # rejected explicitly; otherwise they fall through to the download.
        if current_user.user_id != doc.author_id:
            if not doc.is_published:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")
            doc_group_ids = {g.group_id for g in doc.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")

        safe_title = sanitize_filename(doc.title or "file") or "file"

        # Fetch the file through the configured storage provider and stream it.
        try:
            file_obj, content_type = storage.get_file(doc.file_key)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc

        # Quote the filename parameter (RFC 6266) so spaces or ';' in the
        # title do not truncate it, and escape quoted-string specials.
        quoted_title = safe_title.replace("\\", "\\\\").replace('"', '\\"')
        return StreamingResponse(
            file_obj,
            media_type=content_type,
            headers={
                "Content-Disposition": f'attachment; filename="{quoted_title}"'
            },
        )
3510
+
3511
+
3512
@router.post("/help-request")
def submit_help_request(
    content: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """Persist a help request from the current user, then hand it to the email task.

    The row is committed before the edition check. In a non-cloud deployment
    the request is therefore stored, but the endpoint still answers HTTP 501
    because the email step is unavailable.
    """
    with compair.Session() as session:
        submitted_at = datetime.now(timezone.utc)
        request_row = models.HelpRequest(
            user_id=current_user.user_id,
            content=content,
            datetime_created=submitted_at,
        )
        session.add(request_row)
        session.commit()

        # NOTE(review): the client receives an error even though the request
        # was saved; confirm this is the intended core-edition behavior.
        if not IS_CLOUD:
            raise HTTPException(status_code=501, detail="Help request emails require the Compair Cloud edition.")
        send_help_request_email.delay(request_row.request_id)
    return {"message": "Your request has been submitted. Our team will get back to you soon."}
3529
+
3530
+
3531
@router.post("/deactivate-account")
def submit_deactivate_request(
    notice: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """Persist an account-deactivation request, then hand it to the email task.

    The row is committed before the edition check. In a non-cloud deployment
    the request is therefore stored, but the endpoint still answers HTTP 501
    because the email step is unavailable.
    """
    with compair.Session() as session:
        submitted_at = datetime.now(timezone.utc)
        request_row = models.DeactivateRequest(
            user_id=current_user.user_id,
            notice=notice,
            datetime_created=submitted_at,
        )
        session.add(request_row)
        session.commit()

        # NOTE(review): the client receives an error even though the request
        # was saved; confirm this is the intended core-edition behavior.
        if not IS_CLOUD:
            raise HTTPException(status_code=501, detail="Deactivate request emails require the Compair Cloud edition.")
        send_deactivate_request_email.delay(request_row.request_id)
    return {"message": f"We’ve received your request and will delete your account and data shortly. If you change your mind, reach out within 24 hours at {EMAIL_USER}."}
3548
+
3549
+
3550
# Route paths exposed by the core edition. Every APIRoute on ``router``
# whose path appears here is also registered on ``core_router``.
CORE_PATHS: set[str] = {
    "/sign-up",
    "/verify-email",
    "/login",
    "/forgot-password",
    "/reset-password",
    "/load_session",
    "/update_user",
    "/load_groups",
    "/load_group",
    "/create_group",
    "/join_group",
    "/load_group_users",
    "/delete_group",
    "/load_documents",
    "/load_document",
    "/load_document_by_id",
    "/load_user_files",
    "/create_doc",
    "/update_doc",
    "/publish_doc",
    "/delete_doc",
    "/delete_docs",
    "/process_doc",
    "/status/{task_id}",
    "/upload/ocr-file",
    "/ocr-file-result/{task_id}",
    "/load_chunks",
    "/load_references",
    "/load_feedback",
    "/documents/{document_id}/feedback",
    "/get_activity_feed",
}

# Copy the matching routes (same objects, same order) onto core_router.
for route in router.routes:
    if not isinstance(route, APIRoute):
        continue
    if route.path not in CORE_PATHS:
        continue
    core_router.routes.append(route)
3587
+
3588
+
3589
def create_fastapi_app():
    """Build a standalone FastAPI application that mounts every route on ``router``.

    Kept for backwards compatibility so this module can be run directly.
    FastAPI is imported lazily, inside the factory.
    """
    from fastapi import FastAPI

    application = FastAPI()
    application.include_router(router)
    return application


# Module-level application instance, created when the module is imported.
app = create_fastapi_app()