compair-core 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of compair-core might be problematic. Click here for more details.

Files changed (40)
  1. compair_core/__init__.py +8 -0
  2. compair_core/api.py +3357 -0
  3. {compair → compair_core/compair}/__init__.py +10 -14
  4. {server → compair_core/server}/app.py +1 -1
  5. {compair_core-0.3.7.dist-info → compair_core-0.3.9.dist-info}/METADATA +2 -2
  6. compair_core-0.3.9.dist-info/RECORD +38 -0
  7. compair_core-0.3.9.dist-info/top_level.txt +1 -0
  8. compair_core-0.3.7.dist-info/RECORD +0 -36
  9. compair_core-0.3.7.dist-info/top_level.txt +0 -3
  10. {compair → compair_core/compair}/celery_app.py +0 -0
  11. {compair → compair_core/compair}/default_groups.py +0 -0
  12. {compair → compair_core/compair}/embeddings.py +0 -0
  13. {compair → compair_core/compair}/feedback.py +0 -0
  14. {compair → compair_core/compair}/logger.py +0 -0
  15. {compair → compair_core/compair}/main.py +0 -0
  16. {compair → compair_core/compair}/models.py +0 -0
  17. {compair → compair_core/compair}/schema.py +0 -0
  18. {compair → compair_core/compair}/tasks.py +0 -0
  19. {compair → compair_core/compair}/utils.py +0 -0
  20. {compair_email → compair_core/compair_email}/__init__.py +0 -0
  21. {compair_email → compair_core/compair_email}/email.py +0 -0
  22. {compair_email → compair_core/compair_email}/email_core.py +0 -0
  23. {compair_email → compair_core/compair_email}/templates.py +0 -0
  24. {compair_email → compair_core/compair_email}/templates_core.py +0 -0
  25. {server → compair_core/server}/__init__.py +0 -0
  26. {server → compair_core/server}/deps.py +0 -0
  27. {server → compair_core/server}/local_model/__init__.py +0 -0
  28. {server → compair_core/server}/local_model/app.py +0 -0
  29. {server → compair_core/server}/providers/__init__.py +0 -0
  30. {server → compair_core/server}/providers/console_mailer.py +0 -0
  31. {server → compair_core/server}/providers/contracts.py +0 -0
  32. {server → compair_core/server}/providers/local_storage.py +0 -0
  33. {server → compair_core/server}/providers/noop_analytics.py +0 -0
  34. {server → compair_core/server}/providers/noop_billing.py +0 -0
  35. {server → compair_core/server}/providers/noop_ocr.py +0 -0
  36. {server → compair_core/server}/routers/__init__.py +0 -0
  37. {server → compair_core/server}/routers/capabilities.py +0 -0
  38. {server → compair_core/server}/settings.py +0 -0
  39. {compair_core-0.3.7.dist-info → compair_core-0.3.9.dist-info}/WHEEL +0 -0
  40. {compair_core-0.3.7.dist-info → compair_core-0.3.9.dist-info}/licenses/LICENSE +0 -0
compair_core/api.py ADDED
@@ -0,0 +1,3357 @@
1
+ import hashlib
2
+ import os
3
+ import re
4
+ import requests
5
+ import secrets
6
+ import threading
7
+ import time
8
+ from datetime import datetime, timedelta, timezone
9
+ from typing import Any, Mapping, Optional, Tuple
10
+
11
+ import httpx
12
+ import psutil
13
+ from celery.result import AsyncResult
14
+ from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile
15
+ from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
16
+ from sqlalchemy import distinct, func, select, or_
17
+ from sqlalchemy.orm import joinedload, Session
18
+
19
+ from .server.deps import get_analytics, get_billing, get_ocr, get_settings_dependency, get_storage
20
+ from .server.providers.contracts import Analytics, BillingProvider, OCRProvider, StorageProvider
21
+ from .server.settings import Settings
22
+
23
+ from . import compair
24
+ from .compair import models, schema
25
+ from .compair.embeddings import create_embedding, Embedder
26
+ from .compair.logger import log_event
27
+ from .compair.utils import chunk_text, generate_verification_token, log_activity
28
+ from .compair_email.email import emailer, EMAIL_USER
29
+ from .compair_email.templates import (
30
+ ACCOUNT_VERIFY_TEMPLATE,
31
+ GROUP_INVITATION_TEMPLATE,
32
+ GROUP_JOIN_TEMPLATE,
33
+ INDIVIDUAL_INVITATION_TEMPLATE,
34
+ PASSWORD_RESET_TEMPLATE,
35
+ REFERRAL_CREDIT_TEMPLATE
36
+ )
37
+ from .compair.tasks import process_document_task as process_document_celery, send_feature_announcement_task, send_deactivate_request_email, send_help_request_email
38
+
39
+ import redis
40
+
41
# Shared Redis client. redis-py rejects a missing URL, which made importing this
# module fail whenever REDIS_URL was unset; fall back to the conventional local
# default instead (redis-py only opens connections when a command is issued).
redis_url = os.environ.get("REDIS_URL") or "redis://localhost:6379/0"
redis_client = redis.Redis.from_url(redis_url)

router = APIRouter()
WEB_URL = os.environ.get("WEB_URL")
ADMIN_API_KEY = os.environ.get("ADMIN_API_KEY")

# Cloudflare Images configuration.
# NOTE(review): if the account variables are unset these URLs contain the
# literal text "None"; presumably they are only used in the cloud edition — confirm.
CLOUDFLARE_IMAGES_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_ACCOUNT")
CLOUDFLARE_IMAGES_URL_ACCOUNT = os.environ.get("CLOUDFLARE_IMAGES_URL_ACCOUNT")
CLOUDFLARE_IMAGES_BASE_URL = f"https://imagedelivery.net/{CLOUDFLARE_IMAGES_URL_ACCOUNT}"
CLOUDFLARE_IMAGES_TOKEN = os.environ.get("CLOUDFLARE_IMAGES_TOKEN")
CLOUDFLARE_IMAGES_UPLOAD_URL = f"https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_IMAGES_ACCOUNT}/images/v1"

# Google Analytics 4 Measurement Protocol credentials.
GA4_MEASUREMENT_ID = os.getenv("GA4_MEASUREMENT_ID")
GA4_API_SECRET = os.getenv("GA4_API_SECRET")

# True only when COMPAIR_EDITION is "cloud" (case-insensitive); defaults to "core".
IS_CLOUD = os.getenv("COMPAIR_EDITION", "core").lower() == "cloud"
59
+
60
+
61
def require_cloud(feature: str) -> None:
    """Abort the request with HTTP 501 unless this is the Cloud edition.

    Args:
        feature: Human-readable feature name placed in the error detail.

    Raises:
        HTTPException: status 501 when ``IS_CLOUD`` is false.
    """
    if IS_CLOUD:
        return
    message = f"{feature} is only available in the Compair Cloud edition."
    raise HTTPException(status_code=501, detail=message)
64
+
65
+
66
def _user_plan(user: models.User) -> str:
    """Return the user's plan name; a missing or empty ``plan`` counts as ``"free"``."""
    plan = getattr(user, "plan", None)
    return plan if plan else "free"
68
+
69
+
70
def _user_team(user: models.User):
    """Return ``user.team``, or ``None`` when the model has no such attribute."""
    try:
        return user.team
    except AttributeError:
        return None
72
+
73
+
74
def _trial_expiration(user: models.User) -> datetime | None:
    """Return the user's trial expiration date, or ``None`` if the model lacks the column."""
    try:
        return user.trial_expiration_date
    except AttributeError:
        return None
76
+
77
+
78
# Capability flags probed from the model module. Endpoints that check these
# report the corresponding feature as Cloud-only (HTTP 501) when it is absent.
HAS_TEAM = hasattr(models, "Team")
HAS_ACTIVITY = hasattr(models, "Activity")
HAS_REFERRALS, HAS_BILLING, HAS_TRIALS = (
    hasattr(models.User, column)
    for column in ("referral_code", "stripe_customer_id", "trial_expiration_date")
)
83
+
84
+
85
def require_feature(flag: bool, feature: str) -> None:
    """Abort with HTTP 501 for *feature* when its capability *flag* is falsy.

    Args:
        flag: Capability flag (e.g. one of the ``HAS_*`` constants).
        feature: Human-readable feature name placed in the error detail.

    Raises:
        HTTPException: status 501 when ``flag`` is falsy.
    """
    if flag:
        return
    raise HTTPException(
        status_code=501,
        detail=f"{feature} is only available in the Compair Cloud edition.",
    )
88
+
89
def get_current_user(auth_token: str = Header(...)):
    """FastAPI dependency: resolve the ``auth_token`` header to a ``models.User``.

    The user's ``groups`` relationship is loaded eagerly (``joinedload``) so it
    stays readable after the session below is closed.

    Raises:
        HTTPException: 401 if the token is unknown or expired; 404 if the
            session's user no longer exists.
    """
    with compair.Session() as session:
        token_row = (
            session.query(models.Session)
            .filter(models.Session.id == auth_token)
            .first()
        )
        if not token_row:
            raise HTTPException(status_code=401, detail="Invalid or expired session token")

        expires_at = token_row.datetime_valid_until
        # Timestamps stored without a zone are interpreted as UTC.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if datetime.now(timezone.utc) > expires_at:
            raise HTTPException(status_code=401, detail="Invalid or expired session token")

        user = (
            session.query(models.User)
            .options(joinedload(models.User.groups))
            .filter(models.User.user_id == token_row.user_id)
            .first()
        )
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        return user
110
+
111
def get_current_user_with_access_to_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> models.User:
    """Return *current_user* if they may read document *document_id*.

    Access is granted when the user authored the document, shares at least one
    group with it, or the document is published.

    Raises:
        HTTPException: 404 if the document does not exist; 403 otherwise.
    """
    with compair.Session() as session:
        document = (
            session.query(models.Document)
            .filter(models.Document.document_id == document_id)
            .first()
        )
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")

        if document.author_id == current_user.user_id:
            return current_user

        doc_groups = {group.group_id for group in document.groups}
        shares_group = not doc_groups.isdisjoint(
            group.group_id for group in current_user.groups
        )
        if shares_group or document.is_published:
            return current_user

    raise HTTPException(status_code=403, detail="Not authorized to access this document")
131
+
132
def log_service_resource_metrics(service_name="backend", interval_seconds: float = 300):
    """Start periodic logging of this process's memory and CPU usage.

    Each sample emits ``log_event("service_resource", ...)`` with the resident
    memory in MB (2 decimals) and the CPU percent. A failed sample is printed,
    and the next sample is still scheduled ``interval_seconds`` later.

    All sampling runs on daemon ``threading.Timer`` threads, which has two
    effects:
    - The caller returns immediately instead of blocking for the 1-second CPU
      measurement window (this module calls it at import time).
    - A pending timer never keeps the interpreter alive at shutdown.

    Args:
        service_name: Value recorded in each event's ``service`` field.
        interval_seconds: Delay between samples (default 300 s = 5 minutes).
    """

    def schedule(delay):
        # Daemon so the endless re-scheduling chain does not block process exit.
        timer = threading.Timer(delay, log)
        timer.daemon = True
        timer.start()

    def log():
        try:
            process = psutil.Process(os.getpid())
            mem_mb = round(process.memory_info().rss / 1024 / 1024, 2)
            # Blocks this timer thread for 1 s while psutil measures CPU usage.
            cpu_percent = process.cpu_percent(interval=1)
            log_event("service_resource", service=service_name, memory_mb=mem_mb, cpu_percent=cpu_percent)
        except Exception as e:
            print(f"[Resource Log Error] {e}")
        finally:
            schedule(interval_seconds)

    schedule(0)
145
+
146
# Start the periodic resource sampler for this process
# (a frontend process would pass "frontend" instead).
log_service_resource_metrics("backend")

# Run via: fastapi dev api.py
149
+
150
+
151
@router.post("/login")
def login(request: schema.LoginRequest) -> dict:
    """Authenticate by username and password and issue a one-day session token.

    Returns:
        Basic profile fields plus ``auth_token``, the new ``models.Session`` id.

    Raises:
        HTTPException: 401 for an unknown username or wrong password (same
            response for both); 403 if the account status is ``'inactive'``.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.username == request.username).first()
        # Credentials are never logged.
        if not user or not user.check_password(request.password):
            raise HTTPException(status_code=401, detail="Invalid credentials")
        if user.status == 'inactive':
            raise HTTPException(status_code=403, detail="User account is not verified")

        now = datetime.now(tz=timezone.utc)
        user_session = models.Session(
            id=secrets.token_urlsafe(),
            user_id=user.user_id,
            datetime_created=now,
            datetime_valid_until=now + timedelta(days=1),
        )
        session.add(user_session)
        session.commit()
        return {
            "user_id": user.user_id,
            "username": user.username,
            "name": user.name,
            "status": user.status,
            "role": user.role,
            "auth_token": user_session.id,  # Session token the client sends back as auth_token
        }
177
+
178
+
179
@router.get("/username_exists")
def username_exists(username: str) -> dict:
    """Report whether an account with exactly this username exists."""
    with compair.Session() as session:
        found = (
            session.query(models.User)
            .filter(models.User.username == username)
            .first()
        )
    return {"exists": found is not None}
184
+
185
+
186
@router.get("/load_user")
def load_user(
    username: str,
) -> schema.User | None:
    """Return the user whose username equals *username*, or ``None``.

    Uses exact equality, like ``/login`` and ``/username_exists``.
    ``ColumnOperators.match`` performs backend full-text matching and could
    return a different account whose username merely shares search terms.

    NOTE(review): this endpoint requires no authentication.
    """
    with compair.Session() as session:
        q = select(models.User).filter(models.User.username == username)
        row = session.execute(q).fetchone()
        if row is None:
            return None
        user = row[0]
        if user.groups is None:
            user.groups = []
        return user
202
+
203
+
204
@router.get("/load_user_plan")
def load_user(
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return ``{'plan': <name>}`` for the authenticated user (``"free"`` if unset)."""
    plan_name = _user_plan(current_user)
    return {'plan': plan_name}
209
+
210
+
211
@router.get("/load_user_by_id")
def load_user(
    user_id: str,
) -> schema.User | None:
    """Return the user with primary id *user_id*, or ``None`` if absent.

    Note: three handlers in this module are named ``load_user``. Each route is
    registered by its decorator, but the module attribute ``load_user`` is bound
    to whichever definition comes last.
    """
    with compair.Session() as session:
        statement = select(models.User).filter(models.User.user_id == user_id)
        result = session.execute(statement).fetchone()
        if result is None:
            return None
        (found_user,) = result
        if found_user.groups is None:
            found_user.groups = []
        print(f'User: {found_user}')
        return found_user
227
+
228
+
229
@router.get("/load_user_files")
def load_user_files(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Page through documents owned by *connection_id* that the caller may see.

    A document is visible if it belongs to either:
    - a group the caller shares with the connection, or
    - a public group the connection is a member of.

    ``filter_type`` values:
    - ``"published"`` / ``"unpublished"``: filter on ``is_published``.
    - ``"recently_updated"``: modified, or given a note, in the last 7 days.
    - ``"recently_compaired"``: has feedback from the last 7 days.
    - anything else: all visible documents.

    Results are ordered by ``datetime_modified`` desc, then ``datetime_created`` desc.

    Returns:
        ``{"files": [...], "message": None, "total_count": int}``. When access is
        refused the dict carries only ``files`` (empty) and an explanatory ``message``.
    """
    page = max(page, 1)
    page_size = max(page_size, 0)
    with compair.Session() as session:
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"files": [], "message": "User or connection not found."}

        # Security: must share a group
        shared_group_ids = {g.group_id for g in current_user.groups} & {g.group_id for g in connection.groups}
        if not shared_group_ids:
            return {"files": [], "message": "You do not have permission to view this user's affiliations, as you do not share a group together."}

        # Privacy: check setting
        if connection.hide_affiliations:
            return {"files": [], "message": f"{connection.username} has set their profile to private."}

        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        q = select(models.Document).join(models.Document.groups).filter(
            models.Document.user_id == connection_id,
            models.Group.group_id.in_(shared_group_ids)  # Shared groups
            | (
                (models.Group.visibility == "public")  # Public groups
                & models.Group.users.any(models.User.user_id == connection_id)  # ...the connection belongs to
            ),
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups),
        )

        if filter_type == "published":
            q = q.filter(models.Document.is_published == True)  # noqa: E712 (SQL expression)
        elif filter_type == "unpublished":
            q = q.filter(models.Document.is_published == False)  # noqa: E712 (SQL expression)
        elif filter_type == "recently_updated":
            # Documents updated OR with a note in the past week
            q = q.outerjoin(models.Document.notes).filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Note.datetime_created >= week_ago,
                )
            )
        elif filter_type == "recently_compaired":
            # Documents with feedback in the past week
            q = q.join(models.Document.chunks).join(models.Chunk.feedbacks).filter(
                models.Feedback.timestamp >= week_ago
            )
        q = q.order_by(models.Document.datetime_modified.desc(), models.Document.datetime_created.desc())

        # The group/note/feedback joins can return one document on several rows.
        # De-duplicate first, then page in Python, so that every page holds
        # page_size distinct documents and total_count matches the pages.
        files = session.execute(q).unique().scalars().all()
        total_count = len(files)
        offset = (page - 1) * page_size
        page_files = files[offset:offset + page_size]
        return {
            "files": [schema.Document.model_validate(f) for f in page_files],
            "message": None,
            "total_count": total_count,
        }
308
+
309
+
310
@router.get("/load_user_groups")
def load_user_groups(
    connection_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Page through groups of *connection_id* that the caller may see.

    Visible groups are those shared with the caller, plus public groups the
    connection belongs to.

    ``filter_type`` values:
    - ``"internal"`` / ``"public"`` / ``"private"``: filter on visibility.
    - ``"recently_updated"``: groups with a document created in the last 7 days.

    Results are always ordered by name so that paging is deterministic.

    Returns:
        ``{"groups": [...], "message": None, "total_count": int}``, or an empty
        ``groups`` list with a ``message`` when access is refused.
    """
    page = max(page, 1)
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        connection = session.query(models.User).filter(models.User.user_id == connection_id).first()
        if not connection:
            return {"groups": [], "message": "User or connection not found."}

        # The caller must share at least one group with the connection.
        shared_group_ids = {g.group_id for g in current_user.groups} & {g.group_id for g in connection.groups}
        if not shared_group_ids:
            return {"groups": [], "message": "You do not have permission to view this user's affiliations, as you do not share a group together."}

        # Privacy: check setting
        if connection.hide_affiliations:
            return {"groups": [], "message": f"{connection.username} has set their profile to private."}

        groups_query = session.query(models.Group).filter(
            models.Group.group_id.in_(shared_group_ids)  # Shared groups
            | (
                (models.Group.visibility == "public")  # Public groups
                & models.Group.users.any(models.User.user_id == connection_id)  # ...the connection belongs to
            )
        ).options(
            joinedload(models.Group.users),
            joinedload(models.Group.documents),
        )

        if filter_type in ("internal", "public", "private"):
            groups_query = groups_query.filter(models.Group.visibility == filter_type)
        elif filter_type == "recently_updated":
            # EXISTS rather than JOIN, so a group with several recent documents
            # is returned and counted once.
            groups_query = groups_query.filter(
                models.Group.documents.any(models.Document.datetime_created >= week_ago)
            )
        groups_query = groups_query.order_by(models.Group.name.asc())

        total_count = groups_query.count()
        offset = (page - 1) * page_size
        groups = groups_query.offset(offset).limit(page_size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": getattr(group, "document_count", None),
                "user_count": getattr(group, "user_count", None),
                "first_three_user_profile_images": getattr(group, "first_three_user_profile_images", None),
            }
            for group in groups
        ]
        return {"groups": result, "message": None, "total_count": total_count}
381
+
382
+
383
@router.get("/load_user_status")
def load_user_status(
    current_user: models.User = Depends(get_current_user)
) -> str:
    """Return the authenticated user's ``status`` string.

    ``current_user`` is already loaded by the dependency, so no database
    session is needed here.
    """
    return current_user.status
391
+
392
+
393
@router.get("/load_user_status_date")
def load_user_status(
    current_user: models.User = Depends(get_current_user)
) -> datetime:
    """Return the date associated with the user's current status.

    - ``'active'`` returns ``last_payment_date`` (needs billing columns).
    - ``'trial'`` returns ``trial_expiration_date`` (needs trial columns).
    - ``'suspended'`` returns ``status_change_date``.

    Raises:
        HTTPException: 501 when the edition lacks the needed columns;
            403 for any other status.
    """
    if not (HAS_TRIALS or HAS_BILLING):
        raise HTTPException(status_code=501, detail="User status dates are only tracked in the Compair Cloud edition.")
    user_status = current_user.status
    if user_status == 'active':
        require_feature(HAS_BILLING, "Billing history")
        return current_user.last_payment_date
    if user_status == 'trial':
        require_feature(HAS_TRIALS, "Trial management")
        return current_user.trial_expiration_date
    if user_status == 'suspended':
        return current_user.status_change_date
    raise HTTPException(status_code=403, detail='User Inactive')
412
+
413
+
414
@router.get("/load_referral_credits")
def load_referral_credits(
    current_user: models.User = Depends(get_current_user)
) -> Tuple[int, int]:
    """Return ``(earned, pending)`` referral credits for the authenticated user.

    Raises:
        HTTPException: 501 when the edition has no referral columns.
    """
    if not HAS_REFERRALS:
        raise HTTPException(status_code=501, detail="Referral credits are only available in the Compair Cloud edition.")
    return (current_user.referral_credits, current_user.pending_referral_credits)
424
+
425
+
426
def create_user(
    username: str,
    name: str,
    password: str,
    session: Session,
    groups: list[str] | None = None,
    referral_code: str = None
):
    """Register a new user in *session* and return the ``models.User``.

    Steps:
    1. Reject a duplicate username.
    2. Create and commit the user with a verification token.
    3. If the Team model exists, email and mark any pending team invitations.
    4. Attach the requested *groups*. By default, attach instead a new private
       per-user group (with the user as admin) plus the "Welcome to Compair"
       group and an unpublished intro document, when that group exists.
    5. Credit the referrer for a valid *referral_code*, up to 3 credits.

    Raises:
        HTTPException: 400 if *username* is already registered; 501 if a
            referral code is given on an edition without referrals.
    """
    token, expiration = generate_verification_token()
    existing_user = session.query(models.User).filter(models.User.username == username).first()
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already in use")

    user = models.User(
        username=username,
        name=name,
        datetime_registered=datetime.now(),
        verification_token=token.lower(),
        token_expiration=expiration,
    )
    user.set_password(password=password)
    session.add(user)
    session.commit()

    if HAS_TEAM:
        team_invitations = session.query(models.TeamInvitation).filter(
            models.TeamInvitation.email == username,
            models.TeamInvitation.status == "pending",
        ).all()

        for invitation in team_invitations:
            emailer.connect()
            emailer.send(
                subject="You're Invited to Join a Team on Compair",
                sender=EMAIL_USER,
                receivers=[invitation.email],
                html=f"""
                <p>{invitation.inviter.name} has invited you to join their team on Compair!</p>
                <p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>
                """
            )
            invitation.status = "sent"
            session.commit()

    if groups is not None:
        user.groups = load_groups_by_ids(groups)
    else:
        # Default membership: the shared welcome group (if seeded) and a private group.
        welcome_group = session.query(models.Group).filter(
            models.Group.name == "Welcome to Compair"
        ).first()
        welcome_doc = None
        # The welcome group may not exist (e.g. not seeded); skip it rather than
        # failing signup with an AttributeError on None.
        if welcome_group is not None:
            welcome_doc_id = create_doc(
                authorid=user.user_id,
                document_title=f'Introducing {user.name} to Compair',
                document_type='txt',
                document_content="",
                groups=welcome_group.group_id,
                is_published=False,
                current_user=user
            )['document_id']
            welcome_doc = session.query(models.Document).filter(
                models.Document.document_id == welcome_doc_id
            ).first()
        group = models.Group(
            name=username,
            datetime_created=datetime.now(),
            group_image=None,
            category="Private",
            description=f"A private group for {username}",
            visibility="private"
        )
        admin = models.Administrator(user_id=user.user_id)
        session.add(admin)
        session.add(group)
        if welcome_doc is not None:
            session.add(welcome_doc)
        session.commit()
        group.admins.append(admin)
        user.groups = [group] if welcome_group is None else [group, welcome_group]

    # Track referral if a code is provided
    if referral_code:
        require_feature(HAS_REFERRALS, "Referral program")
        referrer = session.query(models.User).filter(models.User.referral_code == referral_code).first()
        if referrer and hasattr(referrer, "referral_credits"):
            max_credits = 3
            credit_value = 10  # $10 per credit
            # Strict `<`: a referrer who already holds max_credits credits earns
            # no more (`<=` allowed a fourth).
            if referrer.referral_credits < max_credits * credit_value:
                if hasattr(user, "referred_by"):
                    user.referred_by = referrer.user_id  # Store who referred them
                if hasattr(referrer, "pending_referral_credits"):
                    referrer.pending_referral_credits += credit_value  # Add pending credit
                session.commit()

    session.add(user)
    session.commit()
    try:
        # NOTE(review): `analytics` is not assigned anywhere in the visible part
        # of this module (only `get_analytics` is imported). If it is not defined
        # elsewhere, this raises NameError, the except below swallows it, and the
        # signup event is silently dropped — confirm.
        analytics.track("user_signup", user.user_id)
    except Exception as exc:
        print(f"analytics track failed: {exc}")
    return user
526
+
527
+
528
@router.get("/load_session")
def load_session(auth_token: str) -> schema.Session | None:
    """Return the session row for *auth_token* if it has not expired.

    Raises:
        HTTPException: 404 if no such session exists; 401 if it has expired.
    """
    with compair.Session() as session:
        found = (
            session.query(models.Session)
            .filter(models.Session.id == auth_token)
            .first()
        )
        if not found:
            raise HTTPException(status_code=404, detail="Session not found")

        expiry = found.datetime_valid_until
        # Stored timestamps without a zone are treated as UTC.
        expiry = expiry if expiry.tzinfo is not None else expiry.replace(tzinfo=timezone.utc)
        if expiry < datetime.now(timezone.utc):
            raise HTTPException(status_code=401, detail="Invalid or expired session token")
        return found
540
+
541
+
542
@router.post("/update_user")
def update_user(
    name: str = Form(None),
    role: str = Form(None),
    group_ids: list[str] = Form(None),
    include_own_documents_in_feedback: str = Form(None),
    default_publish: str = Form(None),
    preferred_feedback_length: str = Form(None),
    hide_affiliations: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Update the authenticated user's profile from form fields.

    Only supplied fields are changed.
    - Boolean fields arrive as strings and are true only for a case-insensitive
      ``"true"``.
    - ``group_ids`` adds memberships; existing ones are kept.
    - ``preferred_feedback_length`` is currently always stored as ``'Brief'``.
    """
    def _as_bool(value: str) -> bool:
        # Form booleans arrive as text.
        return value.lower() == "true"

    with compair.Session() as session:
        if name is not None:
            current_user.name = name
        if role is not None:
            current_user.role = role
        if group_ids is not None:
            # Compare by primary key. These Group objects come from a different
            # session than current_user.groups, so an `in` test on the instances
            # (identity, unless the model defines __eq__) would miss existing
            # memberships and append duplicates.
            existing_ids = {g.group_id for g in current_user.groups}
            new_groups = [g for g in load_groups_by_ids(group_ids) if g.group_id not in existing_ids]
            current_user.groups.extend(new_groups)
        if include_own_documents_in_feedback is not None:
            current_user.include_own_documents_in_feedback = _as_bool(include_own_documents_in_feedback)
        if default_publish is not None:
            current_user.default_publish = _as_bool(default_publish)
        if preferred_feedback_length is not None:
            # Locked to Brief for the time being, whatever value was submitted.
            current_user.preferred_feedback_length = 'Brief'
        if hide_affiliations is not None:
            current_user.hide_affiliations = _as_bool(hide_affiliations)
        # current_user was loaded in the dependency's (closed) session; re-attach it here.
        session.add(current_user)
        session.commit()
575
+
576
@router.get("/update_session_duration")
def update_session_duration(
    user_session: schema.Session,
    new_valid_until: datetime,
) -> None:
    """Set the expiry of an existing session to *new_valid_until*.

    NOTE(review): this endpoint is unauthenticated and is a GET whose
    ``user_session`` comes from the request body. Anyone who knows a session
    id can extend it; consider requiring that session's own auth token.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    with compair.Session() as session:
        db_session = session.query(
            models.Session
        ).filter(
            models.Session.id == user_session.id
        ).first()
        if db_session is None:
            raise HTTPException(status_code=404, detail="Session not found")
        # `.update()` is a Query method, not an ORM-instance method; assign the
        # attribute and let the commit emit the UPDATE.
        db_session.datetime_valid_until = new_valid_until
        session.commit()
589
+
590
+
591
@router.get("/delete_user")
def delete_user(
    current_user: models.User = Depends(get_current_user)
):
    """Delete the authenticated user's account row.

    ``current_user`` was loaded in the dependency's already-closed session.
    The row is therefore re-fetched here and removed with ``Session.delete``,
    so the deletion belongs to this session's commit.

    NOTE(review): whether related rows (sessions, documents, memberships) are
    removed depends on the models' cascade configuration — confirm.
    """
    with compair.Session() as session:
        user = (
            session.query(models.User)
            .filter(models.User.user_id == current_user.user_id)
            .first()
        )
        if user is not None:
            session.delete(user)
            session.commit()
598
+
599
+
600
@router.get("/load_connections")
def load_connections(
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """Page through users who share at least one group with the caller.

    ``filter_type`` values:
    - ``"recently_active"``: users with a ``"create"`` Activity in the last 7
      days. Requires the Activity model; otherwise HTTP 501.
    - ``"recently_compaired"``: users with a document whose chunks have a
      Feedback record timestamped in the last 7 days.
    - anything else: all connections, ordered by name.

    Within those orderings, users are sorted by registration date, newest first.

    Returns:
        ``{"connections": [...], "total_count": int}``.
    """
    page = max(page, 1)
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)

        # Get all groups the user belongs to
        groups = session.query(models.Group).options(joinedload(models.Group.users)).filter(
            models.Group.group_id.in_([g.group_id for g in current_user.groups])
        ).all()
        if not groups:
            return {"connections": [], "total_count": 0}

        # Everyone in those groups except the requesting user.
        connection_ids = {
            member.user_id
            for group in groups
            for member in group.users
            if member.user_id != current_user.user_id
        }

        q = session.query(models.User).filter(models.User.user_id.in_(connection_ids))

        # EXISTS (`.any`) rather than JOIN, so each user is returned and counted
        # once even with many matching activities or feedback rows.
        if filter_type == "recently_active":
            require_feature(HAS_ACTIVITY, "Activity tracking")
            q = q.filter(
                models.User.activities.any(
                    (models.Activity.action == "create")
                    & (models.Activity.timestamp >= week_ago)
                )
            )
        elif filter_type == "recently_compaired":
            q = q.filter(
                models.User.documents.any(
                    models.Document.chunks.any(
                        models.Chunk.feedbacks.any(models.Feedback.timestamp >= week_ago)
                    )
                )
            )
        else:
            q = q.order_by(models.User.name.asc())

        total_count = q.count()
        offset = (page - 1) * page_size
        connections = q.order_by(models.User.datetime_registered.desc()).offset(offset).limit(page_size).all()

        return {
            "connections": [
                {
                    "user_id": connection.user_id,
                    "username": connection.username,
                    "name": connection.name,
                    "datetime_registered": connection.datetime_registered,
                    "status": connection.status,
                    "profile_image": connection.profile_image,
                    "role": connection.role,
                }
                for connection in connections
            ],
            "total_count": total_count,
        }
662
+
663
+
664
@router.get("/all_group_categories")
def all_group_categories(
    current_user: models.User = Depends(get_current_user)
):
    """Return the distinct group categories, sorted ascending.

    Empty/``None`` categories and the ``'Compair'`` category are omitted.
    The result is ``{"categories": [...]}``.
    """
    with compair.Session() as session:
        rows = (
            session.query(distinct(models.Group.category))
            .order_by(models.Group.category.asc())
            .all()
        )
    names = [
        category
        for (category,) in rows
        if category and category != 'Compair'
    ]
    return {"categories": names}
678
+
679
+
680
@router.get("/load_groups")
def load_groups(
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    category: str | None = None,
    visibility: str | None = None,
    sort: str | None = None,
    query: str | None = None,
    own_groups_only: bool = False,
    current_user: models.User = Depends(get_current_user)
) -> dict | None:
    """Search and page through groups.

    Which groups are eligible:
    - Without ``user_id``: every non-private group.
    - With ``user_id``: public groups, plus groups the caller belongs to
      (category ``'Compair'`` excluded) or has a ``"sent"`` invitation for.
      ``user_id`` only selects this mode; membership always comes from
      ``current_user``.
    - ``own_groups_only`` drops the public-only groups from user mode.

    Further narrowing:
    - ``category`` / ``visibility``: exact match; ``"all"`` means no filter.
    - ``filter_type="pending"`` (user mode only): invited groups only.
    - ``query``: case-insensitive name substring.

    ``sort`` values:
    - ``"popular"``: member count.
    - ``"recently_updated"``: latest document modification.
    - ``"recently_created"``: creation time.
    - default: name.

    Returns:
        ``{"groups": [...], "total_count": int}``.
    """
    page = max(page, 1)  # a negative OFFSET is a SQL error
    with compair.Session() as session:
        invited_group_ids = set()
        # --- User-based group selection ---
        if user_id is None:
            q = session.query(models.Group).options(
                joinedload(models.Group.users),
                joinedload(models.Group.documents)
            ).filter(
                models.Group.visibility != 'private'
            )
        else:
            user = session.query(models.User).options(
                joinedload(models.User.groups).joinedload(models.Group.users),
                joinedload(models.User.groups).joinedload(models.Group.documents)
            ).filter(
                models.User.user_id == current_user.user_id
            ).first()
            user_group_ids = [g.group_id for g in user.groups if g.category != 'Compair']

            invitations = session.query(models.GroupInvitation).filter(
                models.GroupInvitation.email == user.username,
                models.GroupInvitation.status == "sent"
            ).all()
            invited_group_ids.update(i.group_id for i in invitations)

            accessible_group_ids = set(user_group_ids) | invited_group_ids

            if own_groups_only:
                # Only groups the user is a member of or invited to
                q = session.query(models.Group).filter(
                    models.Group.group_id.in_(accessible_group_ids)
                )
            else:
                # Public groups, or any group the user belongs to / is invited to
                q = session.query(models.Group).filter(
                    (models.Group.visibility == "public") |
                    (models.Group.group_id.in_(accessible_group_ids))
                )

        # --- Filtering ---
        if category and category.lower() != "all":
            q = q.filter(models.Group.category == category)
        if visibility and visibility.lower() != "all":
            q = q.filter(models.Group.visibility == visibility)
        # filter_type == "joined" needs no extra filter; use own_groups_only for that.
        if filter_type == "pending" and user_id:
            # Groups with a "sent" invitation for this user; these are exactly
            # the invitations collected above.
            q = q.filter(models.Group.group_id.in_(invited_group_ids))
        if query:
            q = q.filter(models.Group.name.ilike(f"%{query}%"))

        # --- Sorting ---
        if sort == "popular":
            q = q.outerjoin(models.Group.users).group_by(models.Group.group_id).order_by(func.count(models.User.user_id).desc())
        elif sort == "recently_updated":
            q = q.outerjoin(models.Group.documents).group_by(models.Group.group_id).order_by(func.max(models.Document.datetime_modified).desc())
        elif sort == "recently_created":
            q = q.order_by(models.Group.datetime_created.desc())
        else:
            q = q.order_by(models.Group.name.asc())

        # --- Paging ---
        total_count = q.count()
        offset = (page - 1) * page_size
        groups = q.offset(offset).limit(page_size).all()

        result = [
            {
                "group_id": group.group_id,
                "name": group.name,
                "datetime_created": group.datetime_created,
                "group_image": group.group_image,
                "category": group.category,
                "description": group.description,
                "visibility": group.visibility,
                "document_count": group.document_count,
                "user_count": group.user_count,
                "first_three_user_profile_images": group.first_three_user_profile_images,
            }
            for group in groups
        ]
        return {"groups": result, "total_count": total_count}
784
+
785
+
786
def load_groups_by_ids(group_ids: list[str]) -> list[schema.Group]:
    """Fetch every Group whose id appears in ``group_ids``.

    The ORM ``Group`` rows are returned as-is (no schema conversion). They
    are loaded by a short-lived session opened here, so callers presumably
    receive detached instances once the ``with`` block exits.
    """
    with compair.Session() as session:
        matching_groups = (
            session.query(models.Group)
            .filter(models.Group.group_id.in_(group_ids))
            .all()
        )
        return matching_groups
792
+
793
+
794
@router.get("/load_group")
def load_group(
    name: str | None = None,
    group_id: str | None = None
) -> schema.Group | None:
    """Look up one group, by id when given, otherwise by name.

    The name lookup uses the column ``match`` operator (backend-specific text
    matching), not plain equality. Returns None when neither identifier is
    supplied or when nothing matches.
    """
    if name is None and group_id is None:
        return None
    with compair.Session() as session:
        if group_id is not None:
            criterion = models.Group.group_id == group_id
        else:
            criterion = models.Group.name.match(name)
        row = session.execute(select(models.Group).filter(criterion)).fetchone()
        return None if row is None else row[0]
813
+
814
+
815
def notify_group_admins(
    group: models.Group,
    user_id: str
):
    """Send an email notification to group admins.

    Emails every administrator of ``group`` (addressed by their username)
    that the user ``user_id`` asked to join, using GROUP_JOIN_TEMPLATE.

    Nothing is sent, and only a console message is printed, when the group
    has no admins or the requesting user cannot be found.
    """
    with compair.Session() as session:
        admin_emails = [admin.user.username for admin in group.admins]
        if not admin_emails:
            print("No admins found for group:", group.name)
            return
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if user is None:
            # Without the requester the template cannot be filled in; skip the
            # email instead of raising inside the caller's join flow.
            print("Join request notification skipped; user not found:", user_id)
            return
        emailer.connect()
        print("Admin emails:", admin_emails)
        html = (
            GROUP_JOIN_TEMPLATE
            .replace("{{ user_name }}", user.username)
            .replace("{{ group_name }}", group.name)
            .replace("{{ admin_panel_url }}", f"http://{WEB_URL}/admin/groups")
        )
        emailer.send(
            subject="Group Join Request",
            sender=EMAIL_USER,
            receivers=admin_emails,
            html=html,
        )
840
+
841
@router.post("/join_group")
def join_group(
    group_id: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """HTTP entry point: join ``group_id`` as the authenticated user.

    All visibility, invitation and join-request rules live in
    ``join_group_direct``; this wrapper only supplies the caller's id.
    """
    requester_id = current_user.user_id
    return join_group_direct(user_id=requester_id, group_id=group_id)
850
+
851
def join_group_direct(
    user_id: str,
    group_id: str
):
    """Join a group based on its visibility.

    Rules:
    - ``public``: the user is added to the group immediately.
    - ``internal``: the user is added when they hold a pending ("sent")
      invitation; otherwise a JoinRequest is recorded (once) and the group
      admins are emailed about it.
    - ``private``: always refused.

    For public and internal groups, pending invitations addressed to the
    user's username are marked "accepted". Joining a group the user already
    belongs to is a no-op that still reports success.

    Returns:
        A ``{"message": ...}`` dict, or None for any other visibility value.

    Raises:
        HTTPException(404): the group or the user does not exist.
        HTTPException(403): the group is private.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        if group.visibility == "private":
            raise HTTPException(status_code=403, detail="Cannot join private group without an invite")
        if group.visibility not in ("public", "internal"):
            return None

        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")

        # Look for any existing invitations associated with this group
        invitations = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username,
            models.GroupInvitation.status == "sent"
        ).all()
        for invitation in invitations:
            invitation.status = "accepted"

        if group.visibility == "public" or invitations:
            # Avoid inserting a second membership row for an existing member.
            already_member = session.query(models.User).join(models.User.groups).filter(
                models.Group.group_id == group_id,
                models.User.user_id == user_id,
            ).first() is not None
            if not already_member:
                group.users.append(user)
            session.commit()  # also persists any accepted invitations
            if not already_member:
                log_activity(
                    session=session,
                    user_id=user_id,
                    group_id=group.group_id,
                    action="join",
                    object_id=group.group_id,
                    object_name=group.name,
                    object_type="group"
                )
            return {"message": "Joined group successfully"}

        # Internal group without an invitation: create a JoinRequest if not already present
        existing_request = session.query(models.JoinRequest).filter(
            models.JoinRequest.user_id == user_id,
            models.JoinRequest.group_id == group_id
        ).first()
        if not existing_request:
            join_request = models.JoinRequest(
                user_id=user_id,
                group_id=group_id,
                datetime_requested=datetime.now(timezone.utc)
            )
            session.add(join_request)
            session.commit()
            # Only email admins for a new request, not on repeated attempts.
            notify_group_admins(group, user_id)
        return {"message": "Join request sent to group admins"}
912
+
913
+
914
@router.post("/create_group")
async def create_group(
    name: str = Form(...),
    category: str = Form(None),
    description: str = Form(None),
    visibility: str = Form("public"),
    file: UploadFile = File(None),  # Allow optional file upload
    current_user: models.User = Depends(get_current_user)
):
    """Create a group with the current user as its administrator and member.

    An unknown ``category`` is replaced with "Other". Creating an
    ``internal`` group requires an active team plan. When ``file`` is
    supplied it is stored as the group image via ``upload_group_image``.

    Raises:
        HTTPException(403): ``visibility == "internal"`` without an active
            team plan.
    """
    with compair.Session() as session:
        if category not in all_group_categories()['categories']:
            category = "Other"  # Default to "Other" if category is not valid

        # Limit internal group creation to active, team plans
        if visibility == 'internal' and not (current_user.status == 'active' and _user_plan(current_user) == 'team'):
            raise HTTPException(
                status_code=403,
                detail="Internal groups can only be created by users with an active team plan"
            )

        created_group = models.Group(
            name=name,
            group_image=None,
            category=category,
            description=description,
            visibility=visibility,
            # UTC-aware, like the other timestamps written in this file.
            datetime_created=datetime.now(timezone.utc),
        )

        # Reuse the user's Administrator record, creating it on first use.
        admin_row = session.execute(
            select(models.Administrator).filter(
                models.Administrator.user_id == current_user.user_id
            )
        ).fetchone()
        if admin_row is None:
            admin = models.Administrator(user_id=current_user.user_id)
            session.add(admin)
            session.commit()
        else:
            admin = admin_row[0]

        created_group.admins.append(admin)
        session.add(created_group)
        session.commit()

        # Log activity
        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=created_group.group_id,
            action="create",
            object_id=created_group.group_id,
            object_name=created_group.name,
            object_type="group"
        )

        # The creator is also a member of the new group.
        admin.user.groups.append(created_group)
        session.add(admin)
        session.commit()

        # Read the id while the session is open: after commit the instance is
        # (assuming default expire_on_commit) expired and could not refresh
        # once the session is closed.
        new_group_id = created_group.group_id

    if file is not None:
        await upload_group_image(
            group_id=new_group_id,
            upload_type='group',
            file=file
        )
985
+
986
+
987
@router.get("/load_documents")
def load_doc(
    group_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    own_documents_only: bool = True,
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, Any] | None:
    """List documents visible to the current user, one page at a time.

    Args:
        group_id: restrict to documents attached to this group.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: documents per page.
        filter_type: "unpublished" (own documents only), "recently_updated"
            (modified, or with a note created, in the last 7 days),
            "recently_compaired" (feedback in the last 7 days); any other
            value adds no extra filter.
        own_documents_only: True lists only the caller's documents; False
            lists documents in public groups or groups the caller belongs to.

    Returns:
        ``{"documents": [...], "total_count": N}``. N counts distinct
        matching documents across all pages. Documents are ordered by
        last modification, then creation, newest first.
    """
    with compair.Session() as session:
        week_ago = datetime.now(timezone.utc) - timedelta(days=7)
        user_group_ids = {g.group_id for g in current_user.groups}

        q = session.query(models.Document)
        if own_documents_only:
            q = q.filter(models.Document.user_id == current_user.user_id)
            if group_id is not None:
                q = q.filter(models.Document.groups.any(models.Group.group_id == group_id))
        else:
            q = q.join(models.Document.groups)
            if group_id:
                q = q.filter(models.Group.group_id == group_id)
            q = q.filter(
                (models.Group.visibility == "public")
                | models.Group.group_id.in_(user_group_ids)
            ).options(
                joinedload(models.Document.groups),
                joinedload(models.Document.user).joinedload(models.User.groups),
            )

        # --- Filter logic: publishing ---
        if filter_type == "unpublished" and own_documents_only:
            q = q.filter(models.Document.is_published == False)  # noqa: E712 (SQL expression)
        else:
            q = q.filter(
                or_(
                    models.Document.is_published == True,  # noqa: E712 (SQL expression)
                    models.Document.user_id == current_user.user_id
                )
            )

        # --- Filter logic: other ---
        if filter_type == "recently_updated":
            # Documents updated OR with a note in the past week
            q = q.outerjoin(models.Document.notes).filter(
                or_(
                    models.Document.datetime_modified >= week_ago,
                    models.Note.datetime_created >= week_ago
                )
            )
        elif filter_type == "recently_compaired":
            # Documents with feedback in the past week
            q = q.join(models.Document.chunks).join(models.Chunk.feedbacks).filter(
                models.Feedback.timestamp >= week_ago
            )

        q = q.order_by(
            models.Document.datetime_modified.desc(),
            models.Document.datetime_created.desc(),
        )

        # The joins above can repeat a document once per matching group, note
        # or feedback row. unique() collapses those repeats, so counting and
        # paging are done on the de-duplicated list instead of in SQL (where
        # COUNT/OFFSET/LIMIT would operate on the duplicated rows).
        matching = [row[0] for row in session.execute(q).unique().fetchall()]
        total_count = len(matching)

        offset = (max(page, 1) - 1) * page_size
        page_documents = matching[offset:offset + max(page_size, 0)]
        return {
            "documents": [
                schema.Document.model_validate(d) for d in page_documents
            ],
            "total_count": total_count
        }
1080
+
1081
+
1082
@router.get("/load_group_users")
def load_group_users(
    group_id: str,
    page: int = 1,
    page_size: int = 10,
    filter_type: str | None = None,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Return one page of the members of ``group_id``.

    Members of public groups are visible to anyone; otherwise the caller
    must belong to the group.

    Args:
        filter_type: "recently_active" (by status change date, newest first),
            "recently_joined" (by registration date, newest first); anything
            else sorts by name ascending.
        page: 1-based page number; values below 1 are treated as 1.

    Returns:
        ``{"users": [...], "total_count": N}``.

    Raises:
        HTTPException(404): the group does not exist.
        HTTPException(403): the group is not public and the caller is not a
            member.
    """
    with compair.Session() as session:
        # Check if the group exists
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        # Retrieve all users associated with the group
        members_query = (
            session.query(models.User)
            .join(models.User.groups)
            .filter(models.Group.group_id == group_id)
        )

        # Compare by id: current_user comes from the auth dependency (presumably
        # a different session), so an identity-based `current_user in
        # group.users` test would not recognise it as a member.
        is_member = members_query.filter(
            models.User.user_id == current_user.user_id
        ).first() is not None
        if not is_member and group.visibility != 'public':
            raise HTTPException(status_code=403, detail="User does not belong to the group")

        # --- Filter logic ---
        if filter_type == "recently_active":
            users_query = members_query.order_by(models.User.status_change_date.desc())
        elif filter_type == "recently_joined":
            users_query = members_query.order_by(models.User.datetime_registered.desc())
        else:
            users_query = members_query.order_by(models.User.name.asc())

        total_count = users_query.count()
        offset = (max(page, 1) - 1) * page_size
        users = users_query.offset(offset).limit(page_size).all()

        return {
            "users": [
                {
                    "user_id": u.user_id,
                    "username": u.username,
                    "name": u.name,
                    "datetime_registered": u.datetime_registered,
                    "status": u.status,
                    "profile_image": u.profile_image,
                    "role": u.role,
                }
                for u in users
            ],
            "total_count": total_count
        }
1134
+
1135
+
1136
@router.get("/load_document")
def load_doc(
    title: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Return the first of the caller's documents whose title matches ``title``.

    Matching uses the column ``match`` operator (backend-specific text
    search). The document's groups and owning user (with that user's groups)
    are eagerly loaded. Returns None when nothing matches.
    """
    stmt = (
        select(models.Document)
        .filter(models.Document.user_id == current_user.user_id)
        .filter(models.Document.title.match(title))
        .options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups),
        )
    )
    with compair.Session() as session:
        hit = session.execute(stmt).unique().fetchone()
        return hit[0] if hit is not None else None
1154
+
1155
+
1156
@router.get("/load_document_by_id")
def load_doc(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Document | None:
    """Return document ``document_id`` if the caller may view it.

    Access is granted to the document's owner (``user_id``), its author
    (``author_id``), or any user sharing at least one group with it. The
    groups and the owning user (with that user's groups) are eagerly loaded.

    Returns:
        The document, or None when it does not exist.

    Raises:
        HTTPException(403): the caller has no access to the document.
    """
    with compair.Session() as session:
        q = select(models.Document).filter(
            models.Document.document_id == document_id
        ).options(
            joinedload(models.Document.groups),
            joinedload(models.Document.user).joinedload(models.User.groups)
        )
        row = session.execute(q).unique().fetchone()
        if row is None:
            return None
        doc = row[0]
        doc_group_ids = {g.group_id for g in doc.groups}
        user_group_ids = {g.group_id for g in current_user.groups}
        shares_group = bool(doc_group_ids & user_group_ids)
        # The owner must always see their own document, even when the author
        # is someone else and the document is in none of the owner's groups.
        is_owner_or_author = current_user.user_id in (doc.user_id, doc.author_id)
        if not (shares_group or is_owner_or_author):
            raise HTTPException(status_code=403, detail="Not authorized to view this document")
        return doc
1177
+
1178
+
1179
@router.post("/update_doc")
def update_doc(
    doc_id: str = Form(...),
    author_id: str = Form(None),
    title: str = Form(None),
    datetime_created: datetime = Form(None),
    group_ids: list[str] = Form(None),
    image_url: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """Update selected fields of one of the caller's own documents.

    Only fields supplied (not None) are changed. ``group_ids`` replaces the
    document's group list entirely.

    Raises:
        HTTPException(404): the document does not exist or is not owned by
            the caller.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        ).first()
        if doc is None:
            raise HTTPException(
                status_code=404,
                detail="Document not found or you do not have permission to update it."
            )
        if author_id is not None:
            doc.author_id = author_id
        if title is not None:
            doc.title = title
        if datetime_created is not None:
            doc.datetime_created = datetime_created
        if group_ids is not None:
            # Load the groups in this same session instead of attaching
            # instances from a separate, already-closed one.
            # NOTE(review): the caller's membership in these groups is not
            # checked here — confirm this is intended.
            doc.groups = session.query(models.Group).filter(
                models.Group.group_id.in_(group_ids)
            ).all()
        if image_url is not None:
            doc.image_url = image_url
        session.commit()
1215
+
1216
+
1217
@router.post("/create_doc")
def create_doc(
    authorid: str = Form(None),
    document_title: str = Form(None),
    document_type: str = Form(None),
    document_content: str = Form(""),
    groups: str = Form(None),  # TODO: Fix how these get submitted; current comma-separated list string
    is_published: bool = Form(False),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
):
    """Create a document for the current user and attach it to groups.

    The steps run in this order:
    1. Suspend the user if their trial has expired.
    2. Enforce the plan's document limit.
    3. Create the document.
    4. Attach the document to the group ids in the comma-separated
       ``groups`` string. If ``groups`` is missing, empty, or matches no
       group, the group named after the user's username is used instead.
    5. Publish the document if requested.
    6. Log the creation and track it in analytics.

    Returns:
        ``{"document_id": ...}`` for the new document.

    Raises:
        HTTPException(403): the user's document limit has been reached.
    """
    with compair.Session() as session:
        # Re-load the user in this session so changes below are persisted.
        current_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()

        # Check if the trial has expired
        trial_expiration = _trial_expiration(current_user)
        if HAS_TRIALS and trial_expiration and current_user.status == "trial" and trial_expiration < datetime.now(timezone.utc):
            current_user.status = "suspended"  # Mark as suspended once the trial expires
            current_user.status_change_date = datetime.now(timezone.utc)
            session.commit()

        # Enforce document limits
        team = _user_team(current_user)
        if IS_CLOUD and HAS_TEAM and team and current_user.status == "active":
            document_limit = team.total_documents_limit  # type: ignore[union-attr]
        elif IS_CLOUD and _user_plan(current_user) == 'individual' and current_user.status == "active":
            document_limit = 100
        else:
            document_limit = int(os.getenv("COMPAIR_CORE_DOCUMENT_LIMIT", "10"))
        document_count = session.query(models.Document).filter(models.Document.user_id == current_user.user_id).count()

        if document_count >= document_limit:
            raise HTTPException(
                status_code=403,
                detail=f"Document limit reached. Individual plan users can have 100, team plans have 100 times the number of users (pooled); other plans can have 10",
            )

        now = datetime.now(timezone.utc)
        document = models.Document(
            user_id=current_user.user_id,
            author_id=authorid or current_user.user_id,
            title=document_title,
            content=document_content,
            doc_type=document_type,
            datetime_created=now,
            datetime_modified=now
        )

        # Resolve every requested group. The old code took only the first
        # result row and crashed on None/no match.
        # NOTE(review): the caller's membership in these groups is not checked.
        requested_ids = [gid.strip() for gid in (groups or "").split(",") if gid.strip()]
        target_groups = []
        if requested_ids:
            target_groups = list(session.execute(
                select(models.Group).filter(models.Group.group_id.in_(requested_ids))
            ).scalars().all())
        if not target_groups:
            personal_group = session.execute(
                select(models.Group).filter(models.Group.name == current_user.username)
            ).scalars().first()
            target_groups = [personal_group] if personal_group is not None else []
        document.groups = target_groups

        session.add(document)
        session.commit()

        if is_published:
            # Attempt to publish doc, pending user status check
            publish_doc(
                doc_id=document.document_id,
                is_published=True,
                current_user=current_user
            )

        # Log document creation against the first attached group (if any).
        log_group_id = document.groups[0].group_id if document.groups else None
        log_activity(
            session=session,
            user_id=document.author_id,
            group_id=log_group_id,
            action="create",
            object_id=document.document_id,
            object_name=document.title,
            object_type="document"
        )
        try:
            analytics.track("document_created", document.user_id)
        except Exception as exc:
            print(f"analytics track failed: {exc}")
        # Return the document_id for frontend use
        return {"document_id": document.document_id}
1311
+
1312
+
1313
@router.get("/publish_doc")
def publish_doc(
    doc_id: str,
    is_published: bool,
    current_user: models.User = Depends(get_current_user)
):
    """Set the published flag on one of the caller's own documents.

    Suspended users are refused with 403. A 404 is raised when the document
    does not exist or belongs to someone else. The row is written only when
    the flag actually changes.
    """
    with compair.Session() as session:
        if current_user.status == "suspended":
            raise HTTPException(status_code=403, detail="Your subscription has expired. Renew to publish.")

        owned = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        )
        existing = owned.first()
        if not existing:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to publish it.")
        if existing.is_published == is_published:
            return None
        owned.update({'is_published': is_published})
        session.commit()
1333
+
1334
+
1335
@router.get("/delete_docs")
def delete_docs(
    doc_ids: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete several of the caller's documents, logging each deletion.

    Args:
        doc_ids: comma-separated document ids. Ids not owned by the caller
            are ignored.
    """
    with compair.Session() as session:
        id_list = doc_ids.split(',')
        documents = session.query(models.Document).filter(
            models.Document.document_id.in_(id_list),
            models.Document.user_id == current_user.user_id
        )

        # Materialize first so log_activity can use the session freely.
        for document in documents.all():
            # A document may have no groups; the old `groups[0]` raised IndexError.
            # NOTE(review): confirm log_activity accepts group_id=None.
            group_id = document.groups[0].group_id if document.groups else None
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=document.document_id,
                object_name=document.title,
                object_type="document"
            )

        documents.delete()
        session.commit()
1364
+
1365
+
1366
@router.get("/delete_doc")
def delete_doc(
    doc_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete one of the caller's documents and log the deletion.

    Raises:
        HTTPException(404): the document does not exist or is not owned by
            the caller.
    """
    with compair.Session() as session:
        matches = session.query(models.Document).filter(
            models.Document.document_id == doc_id,
            models.Document.user_id == current_user.user_id
        )
        target = matches.first()
        if target is None:
            raise HTTPException(status_code=404, detail="Document not found or you do not have permission to delete it.")

        # Capture log details before the row is gone. A document may have no
        # groups (the old `groups[0]` raised IndexError).
        # NOTE(review): confirm log_activity accepts group_id=None.
        group_id = target.groups[0].group_id if target.groups else None
        doc_name = target.title

        matches.delete()
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=group_id,
            action="delete",
            object_id=doc_id,
            object_name=doc_name,
            object_type="document"
        )
1395
+
1396
+
1397
@router.post("/process_doc")
async def process_doc(
    doc_id: str = Form(...),
    doc_text: str = Form(...),
    generate_feedback: bool = Form(True),
    current_user: models.User = Depends(get_current_user),
    analytics: Analytics = Depends(get_analytics),
) -> Mapping[str, str | None]:
    """Run document processing for ``doc_id`` with the submitted text.

    Only the document's author may call this (403 otherwise; 404 when the
    document does not exist). Suspended users may still edit, but feedback
    generation is forced off for them.

    On cloud deployments the work is queued with
    ``process_document_celery.delay`` and its task id is returned. Otherwise
    the task function is called directly and ``task_id`` is None.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == doc_id).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")
        if doc.author_id != current_user.user_id:
            raise HTTPException(status_code=403, detail="Only the author can edit this document")

        # Suspended users can still edit but receive no new feedback.
        wants_feedback = False if current_user.status == "suspended" else generate_feedback
        job_args = {
            "user_id": current_user.user_id,
            "doc_id": doc_id,
            "doc_text": doc_text,
            "generate_feedback": wants_feedback,
        }
        if IS_CLOUD:
            task_id = process_document_celery.delay(**job_args).id
        else:
            process_document_celery(**job_args)
            task_id = None

        if wants_feedback:
            try:
                analytics.track("feedback_requested", current_user.user_id)
            except Exception as exc:
                print(f"analytics track failed: {exc}")
        return {"task_id": task_id}
1439
+
1440
+
1441
@router.post("/upload/ocr-file")
async def upload_ocr_file(
    file: UploadFile = File(...),
    document_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Hand an uploaded file to the configured OCR provider.

    Returns ``{"task_id": ...}`` with whatever ``ocr.submit`` returned.
    Raises 501 when OCR is disabled or the provider raises
    NotImplementedError, and 500 for any other provider failure.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    payload = await file.read()
    upload_name = file.filename or "upload"
    try:
        submitted_id = ocr.submit(
            user_id=current_user.user_id,
            filename=upload_name,
            data=payload,
            document_id=document_id,
        )
    except NotImplementedError as err:
        raise HTTPException(status_code=501, detail=str(err)) from err
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err)) from err

    return {"task_id": submitted_id}
1466
+
1467
+
1468
@router.get("/ocr-file-result/{task_id}")
def get_ocr_file_result(
    task_id: str,
    current_user: models.User = Depends(get_current_user),
    settings: Settings = Depends(get_settings_dependency),
    ocr: OCRProvider = Depends(get_ocr),
):
    """Report the OCR provider's status for ``task_id``.

    A dict status is merged into the response (its keys override
    ``task_id``). Any other status value is returned under ``"status"``.
    Raises 501 when OCR is disabled or unsupported, and 500 on other
    provider errors.
    """
    if not settings.ocr_enabled:
        raise HTTPException(status_code=501, detail="OCR is not available in this edition.")

    try:
        ocr_state = ocr.status(task_id)
    except NotImplementedError as err:
        raise HTTPException(status_code=501, detail=str(err)) from err
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err)) from err

    payload = {"task_id": task_id}
    if isinstance(ocr_state, dict):
        payload.update(ocr_state)
    else:
        payload["status"] = ocr_state
    return payload
1491
+
1492
+
1493
@router.get("/status/{task_id}")
async def get_process_status(
    task_id: str
):
    """Report the Celery state of a document-processing task.

    Returns ``{"task_id", "status", "result"}``:
    - SUCCESS: ``result`` is the task's return value.
    - PENDING, RECEIVED, STARTED or RETRY: "Task is still processing."
    - anything else (e.g. FAILURE, REVOKED): "Task failed."
    """
    task_result = AsyncResult(task_id)
    # Read the state once; each `.status` access may query the result
    # backend, and the state could change between reads.
    state = task_result.status
    if state == "SUCCESS":
        result = task_result.result
    elif state in ("PENDING", "RECEIVED", "STARTED", "RETRY"):
        result = "Task is still processing."
    else:
        result = "Task failed."
    return {"task_id": task_id, "status": state, "result": result}
1507
+
1508
+
1509
@router.get("/trial_status")
def get_trial_status(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return the trial status for the authenticated user.

    Response keys:
    - ``status``: the user's status.
    - ``days_left``: whole days until ``trial_expiration_date``.
    - ``show_banner``: True when 1-7 days are left.
    - ``checkout_url``: a checkout link, or None.

    A checkout URL is only produced while the banner is shown and billing is
    enabled. A billing customer is created on the fly when the user has no
    ``stripe_customer_id`` yet.

    Raises:
        HTTPException(404): the user has no trial expiration date.
    """
    print('Trial 1')
    require_feature(HAS_TRIALS, "Trial management")
    with compair.Session() as session:
        expires_at = getattr(current_user, "trial_expiration_date", None)
        if expires_at is None:
            raise HTTPException(status_code=404, detail="Trial expiration date not found.")
        remaining = expires_at - datetime.now(timezone.utc)
        days_left = remaining.days
        show_banner = 0 < days_left <= 7

        checkout_url = None
        wants_checkout = show_banner and settings.billing_enabled
        if wants_checkout:
            require_feature(HAS_BILLING, "Billing integration")
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            customer_id = None
            if db_user:
                if not db_user.stripe_customer_id:
                    # Lazily register the user with the billing provider.
                    try:
                        new_customer_id = billing.ensure_customer(
                            user_email=db_user.username,
                            user_id=db_user.user_id,
                        )
                    except NotImplementedError:
                        pass
                    else:
                        db_user.stripe_customer_id = new_customer_id
                        session.commit()
                customer_id = db_user.stripe_customer_id
            if customer_id:
                try:
                    checkout_session = billing.create_checkout_session(
                        customer_id=customer_id,
                        price_id=plan_to_id.get("individual_monthly"),
                        qty=1,
                        success_url=settings.stripe_success_url,
                        cancel_url=settings.stripe_cancel_url,
                        metadata={"plan": "individual_monthly", "user_id": db_user.user_id},
                    )
                    # Providers may return an object with `.url` or a mapping.
                    checkout_url = getattr(checkout_session, "url", None)
                    if checkout_url is None and hasattr(checkout_session, "get"):
                        checkout_url = checkout_session.get("url")
                except NotImplementedError:
                    checkout_url = None
                except Exception as exc:
                    print(f"Error generating Checkout URL: {exc}")

    return {
        "status": current_user.status,
        "days_left": days_left,
        "show_banner": show_banner,
        "checkout_url": checkout_url,
    }
1566
+
1567
+
1568
@router.get("/load_chunks")
def load_chunks(
    document_id: str,
) -> list[schema.Chunk]:
    """Return every chunk belonging to ``document_id`` (possibly empty).

    Rows are materialized with ``.all()`` while the session is open. The old
    code returned a lazy Query, which was only iterated after the ``with``
    block had closed the session.
    """
    with compair.Session() as session:
        return session.query(models.Chunk).filter(
            models.Chunk.document_id == document_id
        ).all()
1577
+
1578
+
1579
@router.get("/load_feedback")
def load_feedback(
    chunk_id: str,
) -> schema.Feedback | None:
    """Return the first feedback entry whose source chunk is ``chunk_id``.

    The entry is returned as a plain dict including ``user_feedback`` and
    ``is_hidden``, or None when the chunk has no feedback.
    """
    with compair.Session() as session:
        matches = session.query(models.Feedback).filter(
            models.Feedback.source_chunk_id == chunk_id
        ).all()
        if not matches:
            return None
        first = matches[0]
        return {
            "feedback_id": first.feedback_id,
            "source_chunk_id": first.source_chunk_id,
            "feedback": first.feedback,
            "user_feedback": first.user_feedback,
            "is_hidden": first.is_hidden,
        }
1598
+
1599
+
1600
@router.post("/feedback/{feedback_id}/hide")
def hide_feedback(
    feedback_id: str,
    is_hidden: bool = Form(False)
):
    """Set or clear the hidden flag on one feedback entry.

    Raises HTTPException(404) when the feedback does not exist.
    NOTE(review): no caller authentication or ownership check is visible here.
    """
    with compair.Session() as session:
        target = session.query(models.Feedback).filter(
            models.Feedback.feedback_id == feedback_id
        ).first()
        if not target:
            raise HTTPException(status_code=404, detail="Feedback not found")
        target.is_hidden = is_hidden
        session.commit()
    verb = 'hidden' if is_hidden else 'unhidden'
    return {"message": f"Feedback {verb} successfully", "feedback_id": feedback_id}
1612
+
1613
+
1614
+
1615
@router.post("/feedback/{feedback_id}/rate")
def rate_feedback(
    feedback_id: str,
    user_feedback: str = Body(..., embed=True)  # expects "positive" or "negative"
):
    """Store the user's rating ("positive"/"negative") on a feedback entry.

    Raises HTTPException(400) for any other label and HTTPException(404) when
    the feedback does not exist.
    """
    allowed_labels = ("positive", "negative", None)
    if user_feedback not in allowed_labels:
        raise HTTPException(status_code=400, detail="Invalid feedback label")
    with compair.Session() as session:
        target = session.query(models.Feedback).filter(
            models.Feedback.feedback_id == feedback_id
        ).first()
        if not target:
            raise HTTPException(status_code=404, detail="Feedback not found")
        target.user_feedback = user_feedback
        session.commit()
    return {
        "message": "Feedback rated successfully",
        "feedback_id": feedback_id,
        "user_feedback": user_feedback,
    }
1629
+
1630
+
1631
@router.get("/documents/{document_id}/feedback")
def list_document_feedback(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Return feedback entries for a document the current user can access.

    Access is granted to the owner, to anyone sharing a group with the
    document, or to anyone when the document is published. Entries are
    gathered through the document's chunks, newest first.

    Raises:
        HTTPException(404): the document does not exist.
        HTTPException(403): the caller may not access it.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(
            models.Document.document_id == document_id
        ).first()
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")

        caller_groups = {g.group_id for g in current_user.groups}
        doc_groups = {g.group_id for g in doc.groups}
        is_owner = doc.user_id == current_user.user_id
        if not (is_owner or doc_groups & caller_groups or doc.is_published):
            raise HTTPException(status_code=403, detail="Not authorized to access this document")

        rows = (
            session.query(models.Feedback)
            .join(models.Chunk, models.Feedback.source_chunk_id == models.Chunk.chunk_id)
            .filter(models.Chunk.document_id == document_id)
            .order_by(models.Feedback.timestamp.desc())
            .all()
        )
        entries = [
            {
                "feedback_id": item.feedback_id,
                "chunk_id": item.source_chunk_id,
                "feedback": item.feedback,
                "user_feedback": item.user_feedback,
                "timestamp": item.timestamp,
            }
            for item in rows
        ]
        return {
            "document_id": document_id,
            "count": len(entries),
            "feedback": entries,
        }
1667
+
1668
+
1669
@router.get("/load_references")
def load_references(
    chunk_id: str,
) -> list[schema.Reference]:
    """Return the references recorded for ``chunk_id`` as schema objects.

    Each reference embeds its referenced document and that document's owner.
    Their ``groups`` fields are filled with a single empty placeholder group
    rather than real memberships.
    """
    def _placeholder_groups():
        # Placeholder group list; the real group memberships are not returned.
        return [{"group_id": "", "datetime_created": datetime.now(), "name": ""}]

    with compair.Session() as session:
        refs = session.query(models.Reference).filter(
            models.Reference.source_chunk_id == chunk_id
        ).all()
        returned_references: list[schema.Reference] = []
        for ref in refs:
            ref_doc = ref.document
            owner = ref_doc.user
            embedded_doc = schema.Document(
                document_id=ref_doc.document_id,
                user_id=ref_doc.user_id,
                author_id=ref_doc.author_id,
                title=ref_doc.title,
                content=ref_doc.content,
                doc_type=ref_doc.doc_type,
                datetime_created=ref_doc.datetime_created,
                datetime_modified=ref_doc.datetime_modified,
                is_published=ref_doc.is_published,
                groups=_placeholder_groups(),
                user=schema.User(
                    user_id=owner.user_id,
                    username=owner.username,
                    name=owner.name,
                    groups=_placeholder_groups(),
                    datetime_registered=owner.datetime_registered,
                    status=owner.status
                )
            )
            returned_references.append(
                schema.Reference(
                    reference_id=ref.reference_id,
                    source_chunk_id=ref.source_chunk_id,
                    reference_document_id=ref.reference_document_id,
                    document=embedded_doc,
                    document_author=owner.username
                )
            )
        print(f'Returned references: {returned_references}')
        return returned_references
1707
+
1708
@router.get("/verify-email")
def verify_email(token: str):
    """Verify a user's email address using the token from the sign-up email.

    On success:
      * the user moves to status ``'trial'`` (plus a 30-day trial expiration
        when ``HAS_TRIALS`` is set);
      * the verification token is cleared so it cannot be reused;
      * every group invitation still ``"pending"`` for this address is marked
        ``"sent"`` and emailed.

    Raises:
        HTTPException(400): the token is unknown or has expired.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.verification_token == token).first()
        # Check for a missing user before reading any attribute of it: the old
        # debug output dereferenced ``user`` first and turned an invalid token
        # into a 500. The token itself is no longer printed either.
        if not user:
            raise HTTPException(status_code=400, detail="Invalid or expired token")
        if user.token_expiration < datetime.now(timezone.utc):
            raise HTTPException(status_code=400, detail="Token has expired")
        user.status = 'trial'
        user.status_change_date = datetime.now(timezone.utc)
        if HAS_TRIALS:
            user.trial_expiration_date = datetime.now(timezone.utc) + timedelta(days=30)
        user.verification_token = None  # Invalidate the token
        session.commit()

        pending_invitations = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.email == user.username,
            models.GroupInvitation.status == "pending"
        ).all()
        for invitation in pending_invitations:
            invitation.status = "sent"
            if not invitation.token:
                invitation.token = secrets.token_urlsafe(32).lower()
            # The invitee only receives the link now, so the 7-day window
            # starts from this email rather than from when it was created.
            invitation.datetime_expiration = datetime.now(timezone.utc) + timedelta(days=7)
            session.commit()

            group = invitation.group
            inviter = invitation.inviter
            invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={invitation.token}&user_id={user.user_id}"
            # Elsewhere this same template is filled using the "{{ key }}"
            # spelling; substitute both spellings so neither is left unfilled.
            html = GROUP_INVITATION_TEMPLATE
            for key, value in (
                ("inviter_name", inviter.name),
                ("group_name", group.name),
                ("invitation_link", invitation_link),
            ):
                html = html.replace("{{" + key + "}}", value).replace("{{ " + key + " }}", value)
            emailer.connect()
            emailer.send(
                subject="You’re Invited to Join a Group on Compair",
                sender=EMAIL_USER,
                receivers=[user.username],
                html=html,
            )

    return {"message": "Email verified successfully. Your free trial has started!"}
1759
+
1760
def is_valid_email(email):
    """Return True when *email* looks like ``local@domain.tld``.

    This is a deliberately loose syntactic check: exactly one ``@``, at least
    one dot after it, non-empty parts and no whitespace. It is not full
    RFC 5322 validation. The whole string must match; plain ``re.match``
    only anchors at the start, so inputs such as ``"a@b.c@d"`` used to pass.
    Non-string input (e.g. ``None``) returns False instead of raising.
    """
    if not isinstance(email, str):
        return False
    return re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email) is not None
1762
+
1763
@router.post("/sign-up")
def sign_up(
    request: schema.SignUpRequest,
    analytics: Analytics = Depends(get_analytics),
) -> dict:
    """Create a new account and email the address-verification link.

    ``analytics`` is injected but not currently used by this handler.

    Raises:
        HTTPException(400): ``request.username`` is not a valid email address.
    """
    if not is_valid_email(request.username):
        raise HTTPException(status_code=400, detail="Invalid email address")
    with compair.Session() as session:
        user = create_user(
            username=request.username,
            name=request.name,
            password=request.password,
            groups=request.groups,
            session=session,
            referral_code=request.referral_code,
        )
        # Link consumed by the /verify-email endpoint.
        verification_link = f"http://{WEB_URL}/verify-email?token={user.verification_token}"
        emailer.connect()
        emailer.send(
            subject="Verify your email address",
            sender=EMAIL_USER,
            receivers=[user.username],
            html=ACCOUNT_VERIFY_TEMPLATE.replace("{{verification_link}}", verification_link)
        )
    return {"message": "Sign-up successful. Please check your email for verification."}
1798
+
1799
@router.post("/forgot-password")
def forgot_password(request: schema.ForgotPasswordRequest) -> dict:
    """Email a password-reset link to ``request.email`` if it is registered.

    The same response is returned whether or not the address exists, so the
    endpoint does not reveal which emails have accounts. The reset token is
    stored lower-cased with its expiration. It is intentionally never
    printed or logged, because it grants a password reset.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.username == request.email).first()
        if not user:
            return {"message": "If the email exists, a reset link will be sent."}

        token, expiration = generate_verification_token()
        token = token.lower()
        user.reset_token = token
        user.token_expiration = expiration
        session.commit()

        reset_link = f"http://{WEB_URL}/reset-password?token={token}"
        emailer.connect()
        emailer.send(
            subject="Password Reset Request",
            sender=EMAIL_USER,
            receivers=[request.email],
            html=PASSWORD_RESET_TEMPLATE.replace("{{reset_link}}", reset_link)
        )
    return {"message": "If the email exists, a reset link will be sent."}
1830
+
1831
@router.post("/reset-password")
def reset_password(request: schema.ResetPasswordRequest) -> dict:
    """Set a new password using a reset token issued by /forgot-password.

    The token and its expiration are cleared after use. The token is not
    printed or logged because it is a bearer credential.

    Raises:
        HTTPException(400): the token is unknown or has expired.
    """
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.reset_token == request.token).first()
        if not user or user.token_expiration < datetime.now(timezone.utc):
            raise HTTPException(status_code=400, detail="Invalid or expired token")

        user.set_password(request.new_password)
        user.reset_token = None  # single use
        user.token_expiration = None
        session.commit()

    return {"message": "Password has been reset successfully"}
1851
+
1852
@router.get("/admin/groups")
def get_admin_groups(
    current_user: models.User = Depends(get_current_user),
):
    """List the groups administered by the current user.

    Each entry carries the group's id, name, visibility, category,
    description, image and creation time. A user without an Administrator
    record receives an empty list.
    """
    exported_fields = (
        "group_id",
        "name",
        "visibility",
        "category",
        "description",
        "group_image",
        "datetime_created",
    )
    with compair.Session() as session:
        admin_record = (
            session.query(models.Administrator)
            .filter(models.Administrator.user_id == current_user.user_id)
            .first()
        )
        if not admin_record:
            return []
        return [
            {field: getattr(grp, field) for field in exported_fields}
            for grp in admin_record.groups
        ]
1876
+
1877
@router.get("/admin/join_requests")
def get_join_requests(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """List join requests for *group_id*, visible only to that group's admins.

    Non-admins receive an empty list rather than an error.
    """
    with compair.Session() as session:
        administered_by_caller = models.JoinRequest.group.has(
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        )
        join_requests = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.group_id == group_id)
            .filter(administered_by_caller)
            .all()
        )
        return [
            {
                "request_id": jr.request_id,
                "user_name": jr.user.name,
                "datetime_requested": jr.datetime_requested,
            }
            for jr in join_requests
        ]
1893
+
1894
@router.post("/admin/approve_request")
def approve_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Approve a join request and add the requester to the group.

    Only requests for groups the current user administers are found; any
    other id yields 404. The request row is deleted and a ``"join"``
    activity is logged for the new member.
    """
    with compair.Session() as session:
        request = session.query(models.JoinRequest).filter(
            models.JoinRequest.request_id == request_id
        ).filter(
            models.JoinRequest.group.has(models.Group.admins.any(models.Administrator.user_id == current_user.user_id))
        ).first()
        if not request:
            raise HTTPException(status_code=404, detail="Request not found")

        group: models.Group = request.group
        user: models.User = request.user
        # The requester may already be a member (e.g. via an accepted
        # invitation); appending again would insert a duplicate association.
        if user not in group.users:
            group.users.append(user)
        session.delete(request)
        session.commit()

        log_activity(
            session=session,
            user_id=user.user_id,
            group_id=group.group_id,
            action="join",
            object_id=group.group_id,
            object_name=group.name,
            object_type="group"
        )

    return {"message": "Request approved successfully"}
1926
+
1927
@router.post("/admin/reject_request")
def reject_request(
    request_id: int,
    current_user: models.User = Depends(get_current_user)
):
    """Reject (delete) a join request for a group the caller administers.

    Raises:
        HTTPException(404): no such request, or the caller is not an admin
            of its group.
    """
    with compair.Session() as session:
        caller_is_admin = models.JoinRequest.group.has(
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        )
        join_request = (
            session.query(models.JoinRequest)
            .filter(models.JoinRequest.request_id == request_id, caller_is_admin)
            .first()
        )
        if not join_request:
            raise HTTPException(status_code=404, detail="Request not found")

        session.delete(join_request)
        session.commit()

    return {"message": "Request rejected successfully"}
1945
+
1946
@router.post("/admin/update_group")
def update_group(
    group_id: str,
    name: Optional[str] = Form(None),
    visibility: Optional[str] = Form(None),
    category: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Update settings of a group administered by the current user.

    Only fields supplied with a non-empty value are changed; omitted or empty
    form fields leave the stored value as is.

    Raises:
        HTTPException(404): the group does not exist or the caller is not
            one of its administrators.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == current_user.user_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        if name:
            group.name = name
        if visibility:
            group.visibility = visibility
        if category:
            group.category = category
        if description:
            group.description = description
        session.commit()

    return {"message": "Group updated successfully"}
1982
+
1983
@router.post("/admin/invite_member")
def invite_member_to_group(
    request: schema.InviteMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """Invite an existing Compair user (by username) to join a group.

    Creates a GroupInvitation valid for 7 days, emails the invitation link,
    and marks the invitation ``'sent'``.

    Raises:
        HTTPException(403): ``request.admin_id`` is not the authenticated
            caller, or the caller does not administer the group.
        HTTPException(404): unknown group or user.
        HTTPException(400): the user is already a member or already invited.
    """
    # These assignments previously ended in trailing commas
    # (``admin_id = request.admin_id,``), which bound 1-tuples and made every
    # query below compare a column against a tuple.
    admin_id = request.admin_id
    group_id = request.group_id
    username = request.username
    # The acting admin comes from the request body; require it to be the
    # authenticated caller so nobody can act as another group's admin.
    if str(admin_id) != str(current_user.user_id):
        raise HTTPException(status_code=403, detail="You are not authorized to manage this group")
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        # Usernames are email addresses in this system.
        user = session.query(models.User).filter(models.User.username == username).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        if user in group.users:
            raise HTTPException(status_code=400, detail="User is already a member")
        existing_invite = session.query(models.GroupInvitation).filter(
            models.GroupInvitation.group_id == group_id,
            models.GroupInvitation.email == user.username
        ).first()
        if existing_invite:
            raise HTTPException(status_code=400, detail="User already invited")

        token = secrets.token_urlsafe(32).lower()
        invitation = models.GroupInvitation(
            group_id=group_id,
            inviter_id=admin_id,
            token=token,
            email=user.username,
            # Timezone-aware, to match the aware ``datetime.now(timezone.utc)``
            # used when invitations are accepted.
            datetime_expiration=datetime.now(timezone.utc) + timedelta(days=7),
            status='pending'
        )
        session.add(invitation)
        session.commit()

        invitation_link = f"http://{WEB_URL}/accept-group-invitation?token={token}&user_id={user.user_id}"
        # The same template is filled elsewhere using the "{{key}}" spelling;
        # substitute both spellings so neither is left unfilled.
        html = GROUP_INVITATION_TEMPLATE
        for key, value in (
            ("inviter_name", admin.user.name),
            ("group_name", group.name),
            ("invitation_link", invitation_link),
        ):
            html = html.replace("{{ " + key + " }}", value).replace("{{" + key + "}}", value)
        emailer.connect()
        emailer.send(
            subject="You’re Invited to Join a Group on Compair",
            sender=EMAIL_USER,
            receivers=[user.username],
            html=html,
        )
        invitation.status = 'sent'
        session.commit()
    return {"message": "Invitation sent successfully"}
2058
+
2059
@router.post("/admin/invite_new_user")
def invite_new_user_to_group(
    request: schema.InviteToGroupRequest,
    current_user: models.User = Depends(get_current_user)
) -> dict:
    """Invite a person without a Compair account, on behalf of a group admin.

    Records a ``"pending"`` GroupInvitation for ``request.email`` (valid for
    7 days) and emails the invitee a Compair invitation containing the
    admin's referral link. Pending invitations are emailed as group
    invitations by /verify-email once that address signs up and verifies.

    Raises:
        HTTPException(403): ``request.admin_id`` is not the authenticated
            caller, or the caller does not administer the group.
        HTTPException(404): the group does not exist or is not administered
            by that admin.
    """
    admin_id = request.admin_id
    group_id = request.group_id
    email = request.email
    # The acting admin comes from the request body; require it to be the
    # authenticated caller so nobody can act as another group's admin.
    if str(admin_id) != str(current_user.user_id):
        raise HTTPException(status_code=403, detail="You are not authorized to manage this group")
    with compair.Session() as session:
        group = session.query(models.Group).filter(
            models.Group.group_id == group_id,
            models.Group.admins.any(models.Administrator.user_id == admin_id)
        ).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin = session.query(models.Administrator).filter(
            models.Administrator.user_id == admin_id,
            models.Administrator.groups.contains(group)
        ).first()
        if not admin:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        referral_link = generate_referral_link(admin.user.referral_code)

        emailer.connect()
        if group_id:
            # Stored as "pending" and emailed as a group invitation only after
            # the invitee signs up and verifies their address.
            token = secrets.token_urlsafe(32).lower()
            invitation = models.GroupInvitation(
                group_id=group_id,
                inviter_id=admin.user.user_id,
                token=token,
                email=email,
                # Timezone-aware, matching the acceptance-time comparison.
                datetime_expiration=datetime.now(timezone.utc) + timedelta(days=7),
                status="pending"
            )
            session.add(invitation)
            session.commit()
        emailer.send(
            subject="You're Invited to Compair!",
            sender=EMAIL_USER,
            receivers=[email],
            html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                "{{inviter_name}}", admin.user.name
            ).replace(
                "{{referral_link}}", referral_link,
            )
        )

    return {"message": "Invitation to join Compair successful."}
2114
+
2115
@router.post("/admin/remove_member")
def remove_member(
    request: schema.RemoveMemberRequest,
    current_user: models.User = Depends(get_current_user)
):
    """Remove ``request.user_id`` from ``request.group_id``.

    Only administrators of the group may do this. The removal is logged as a
    ``"remove_member"`` activity attributed to the acting admin.

    Raises:
        HTTPException(404): unknown group, or the user is not a member.
        HTTPException(403): the caller does not administer the group.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == request.group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")

        admin_user_ids = {adm.user_id for adm in group.admins}
        if current_user.user_id not in admin_user_ids:
            raise HTTPException(status_code=403, detail="You are not authorized to manage this group")

        member = session.query(models.User).filter(models.User.user_id == request.user_id).first()
        is_member = bool(member) and member in group.users
        if not is_member:
            raise HTTPException(status_code=404, detail="User is not a member of this group")

        group.users.remove(member)
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=group.group_id,
            action="remove_member",
            object_id=member.user_id,
            object_name=member.name,
            object_type="user"
        )

        return {"message": f"User {member.name} has been removed from the group"}
2152
+
2153
@router.get("/accept_group_invitation")
def accept_group_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a group invitation using its token and join the group.

    Raises:
        HTTPException(404): unknown token, or an invitation that was already
            accepted or marked expired (tokens are single-use).
        HTTPException(400): the invitation's expiration time has passed; it is
            marked ``"expired"``.

    NOTE(review): the invitation's ``email`` is not compared with the
    caller's username, so anyone holding the link can redeem it.
    """
    with compair.Session() as session:
        invitation = session.query(models.GroupInvitation).filter(models.GroupInvitation.token == token).first()
        if not invitation or invitation.status in ("accepted", "expired"):
            raise HTTPException(status_code=404, detail="Invalid or expired invitation")

        expires_at = invitation.datetime_expiration
        if expires_at is not None and expires_at.tzinfo is None:
            # Some invitations were stored with naive UTC timestamps
            # (``datetime.utcnow()``); comparing naive with aware raises
            # TypeError, so interpret them as UTC.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at is not None and expires_at < datetime.now(timezone.utc):
            invitation.status = "expired"
            session.commit()
            raise HTTPException(status_code=400, detail="Invitation has expired")

        group = invitation.group
        # ``current_user`` belongs to another session; load a copy bound here.
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()

        # Avoid inserting a duplicate membership row for existing members.
        if user not in group.users:
            group.users.append(user)
        invitation.status = "accepted"
        session.commit()

    return {"message": "You have successfully joined the group"}
2178
+
2179
@router.get("/accept_team_invitation")
def accept_team_invitation(
    token: str,
    current_user: models.User = Depends(get_current_user)
):
    """Accept a team invitation; *token* is the TeamInvitation's id.

    Only invitations in status ``"sent"`` can be accepted. On success the
    caller's ``team_id`` (when the model has one) is set to the invitation's
    team and the invitation is marked ``"accepted"``.

    Raises:
        HTTPException(404): no matching sent invitation.
    """
    require_feature(HAS_TEAM, "Team collaboration")
    with compair.Session() as session:
        invitation = session.query(models.TeamInvitation).filter(
            models.TeamInvitation.invitation_id == token,
            models.TeamInvitation.status == "sent",
        ).first()
        if not invitation:
            raise HTTPException(status_code=404, detail="Invalid or expired invitation.")

        # ``current_user`` is not attached to this session, so assigning to it
        # was never flushed by the commit below. Update a copy loaded here.
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is not None and hasattr(user, "team_id"):
            user.team_id = invitation.team_id
        invitation.status = "accepted"
        session.commit()

    return {"message": "You have successfully joined the team!"}
2200
+
2201
def get_feedback_tooltip(
    feedback_id: str,
) -> str:
    """Return the feedback text for *feedback_id*, for tooltip display.

    Matching entries are joined with ``" | "``. Entries with no text are
    skipped so a NULL ``feedback`` column cannot make ``str.join`` raise
    TypeError. Returns ``""`` when nothing matches.
    """
    with compair.Session() as session:
        feedbacks = session.query(models.Feedback).filter(models.Feedback.feedback_id == feedback_id).all()
        return " | ".join(f.feedback for f in feedbacks if f.feedback)
2208
+
2209
@router.get("/get_activity_feed")
def get_activity_feed(
    user_id: str,
    page: int = 1,
    page_size: int = 10,
    include_own_activities: bool = True,
    current_user: models.User = Depends(get_current_user)
):
    """Return a page of recent activities across the current user's groups.

    Args:
        user_id: accepted for API compatibility but unused; the feed is
            always built for ``current_user``.
        page: 1-based page number.
        page_size: activities per page.
        include_own_activities: when False, the caller's own activities are
            excluded.

    Returns:
        ``{"activities": [...], "total_count": int}``, newest first.
        ``"provided feedback"`` entries carry a ``tooltip`` with the text.

    Raises:
        HTTPException(400): ``page`` or ``page_size`` is below 1.
    """
    require_feature(HAS_ACTIVITY, "Activity feed")
    # page < 1 gives a negative OFFSET and page_size < 1 a non-positive LIMIT;
    # reject them here instead of letting the database error out as a 500.
    if page < 1 or page_size < 1:
        raise HTTPException(status_code=400, detail="page and page_size must be >= 1")
    with compair.Session() as session:
        group_ids = [g.group_id for g in current_user.groups]

        q = (
            session.query(models.Activity)
            .filter(models.Activity.group_id.in_(group_ids))
            .order_by(models.Activity.timestamp.desc())
        )
        if not include_own_activities:
            q = q.filter(models.Activity.user_id != current_user.user_id)

        total_count = q.count()
        offset = (page - 1) * page_size
        activities = q.offset(offset).limit(page_size).all()

        return {
            "activities": [
                {
                    "user": activity.user.name,
                    "user_id": activity.user_id,
                    "group_id": activity.group_id,
                    "action": activity.action,
                    "object": f"{activity.object_type} {activity.object_name}",
                    "object_type": activity.object_type,
                    "object_name": activity.object_name,
                    "object_id": activity.object_id,
                    "timestamp": activity.timestamp,
                    "tooltip": get_feedback_tooltip(activity.object_id) if activity.action == "provided feedback" else None
                }
                for activity in activities
            ],
            "total_count": total_count
        }
2255
+
2256
@router.delete("/delete_group")
def delete_group(
    group_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """Delete a group. Allowed only if the current user is one of its admins.

    This removes the group and its associations. Documents remain but lose
    their link to this group.

    Raises:
        HTTPException(404): unknown group.
        HTTPException(403): the caller is not an admin of the group.
    """
    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if not group:
            raise HTTPException(status_code=404, detail="Group not found")
        is_admin = any(a.user_id == current_user.user_id for a in group.admins)
        if not is_admin:
            raise HTTPException(status_code=403, detail="Not authorized to delete this group")
        name = group.name
        session.delete(group)
        session.commit()
        # Logging is best-effort: the group row is already gone, so
        # log_activity may fail (presumably on a reference to the deleted
        # group). Roll back the failed transaction and record the failure
        # instead of silently swallowing it; the deletion itself stands.
        try:
            log_activity(
                session=session,
                user_id=current_user.user_id,
                group_id=group_id,
                action="delete",
                object_id=group_id,
                object_name=name,
                object_type="group",
            )
        except Exception as exc:
            session.rollback()
            log_event(
                "group_delete_activity_log_failed",
                user_id=current_user.user_id,
                group_id=group_id,
                error=str(exc),
            )
    return {"message": "Group deleted", "group_id": group_id}
2290
+
2291
# Checkout plan name -> billing-provider price id. The checkout endpoint looks
# prices up by plan; the invoice handler scans it in reverse to map a price id
# back to a plan name.
plan_to_id = dict(
    individual_monthly="price_1Ry34UEPPPghXLIJ4OLQNEna",
    individual_annual="price_1Ry34NEPPPghXLIJeftgEYOD",
    team_monthly="price_1Ry34HEPPPghXLIJoEyGRm7h",
    team_annual="price_1Ry34REPPPghXLIJXhLhwpBW",
)
2299
+
2300
@router.post("/create-checkout-session")
def create_checkout_session(
    checkout_map: Mapping,
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Create a checkout session through the configured billing provider.

    Keys read from *checkout_map*:
        plan: a key of ``plan_to_id`` (default ``"individual_monthly"``).
        team_emails: list of emails or a comma-separated string (team plans).
        team_name: optional team name (team plans).

    Individual plans are billed for one seat; team plans for one seat per
    distinct team email. For team plans a Team row and ``"pending"``
    TeamInvitations (for emails without an account) are created before the
    checkout session.

    Returns:
        ``{"url": <checkout url>}``.

    Raises:
        HTTPException: 400 for an unsupported plan or too few team emails,
            404 if the user row is missing, 501 when billing is unavailable,
            500 for provider failures.
    """
    require_cloud("Billing")

    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    plan = checkout_map.get("plan", "individual_monthly")
    raw_emails = checkout_map.get("team_emails", []) or []
    if isinstance(raw_emails, str):
        raw_emails = raw_emails.split(',')
    # Normalise both input forms: trim, drop blanks and exact duplicates, so
    # each billed seat (quantity below) is one distinct address.
    team_emails = list(dict.fromkeys(
        email.strip() for email in raw_emails if isinstance(email, str) and email.strip()
    ))
    if plan not in plan_to_id:
        raise HTTPException(status_code=400, detail="Unsupported plan selected.")

    # NOTE(review): the enforced minimum is 3; the message used to claim 4.
    # The message now reports the rule actually enforced — confirm which
    # number product intends.
    min_team_emails = 3
    if plan.startswith("team") and len(team_emails) < min_team_emails:
        raise HTTPException(
            status_code=400,
            detail=f"Team plans require at least {min_team_emails} emails.",
        )

    price_id = plan_to_id[plan]

    quantity = 1 if plan.startswith("individual") else len(team_emails)

    try:
        with compair.Session() as session:
            db_user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
            if not db_user:
                raise HTTPException(status_code=404, detail="User not found.")

            if not db_user.stripe_customer_id:
                log_event("stripe_customer_create_attempt", user_id=db_user.user_id)
                try:
                    customer_id = billing.ensure_customer(
                        user_email=db_user.username,
                        user_id=db_user.user_id,
                    )
                except NotImplementedError as exc:
                    raise HTTPException(status_code=501, detail=str(exc)) from exc
                db_user.stripe_customer_id = customer_id
                session.commit()
                log_event("stripe_customer_created", user_id=db_user.user_id, customer_id=customer_id)

            customer_id = db_user.stripe_customer_id

            if plan.startswith("team"):
                require_feature(HAS_TEAM, "Team collaboration")
                team_name = checkout_map.get("team_name", f"{db_user.name}'s Team")
                team = models.Team(
                    name=team_name,
                    total_documents_limit=100 * len(team_emails),
                    daily_feedback_limit=50 * len(team_emails),
                )
                session.add(team)
                session.commit()

                for email in team_emails:
                    existing_user = session.query(models.User).filter(models.User.username == email).first()
                    if not existing_user:
                        invitation = models.TeamInvitation(
                            email=email,
                            inviter_id=db_user.user_id,
                            team_id=team.team_id,
                            status="pending",
                            datetime_created=datetime.now(timezone.utc),
                        )
                        session.add(invitation)
                session.commit()

            log_event("stripe_checkout_session_start", user_id=db_user.user_id)

            try:
                session_info = billing.create_checkout_session(
                    customer_id=customer_id,
                    price_id=price_id,
                    qty=quantity,
                    success_url=settings.stripe_success_url,
                    cancel_url=settings.stripe_cancel_url,
                    metadata={"plan": plan, "user_id": db_user.user_id},
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc
            except Exception as exc:
                raise HTTPException(status_code=500, detail=str(exc)) from exc

            # Providers may return an object with attributes or a mapping.
            session_id = getattr(session_info, "id", None)
            session_url = getattr(session_info, "url", None)
            if session_id is None and hasattr(session_info, "get"):
                session_id = session_info.get("id")
            if session_url is None and hasattr(session_info, "get"):
                session_url = session_info.get("url")
            if not session_id or not session_url:
                raise HTTPException(status_code=500, detail="Billing provider did not return a checkout URL.")

            if hasattr(db_user, "checkout_session"):
                db_user.checkout_session = session_id
            if hasattr(db_user, "plan"):
                db_user.plan = plan
            session.commit()
            log_event("stripe_checkout_session_created", user_id=db_user.user_id, session_id=session_id)
            return {"url": session_url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
2409
+
2410
+
2411
@router.get("/redirect-to-checkout")
def redirect_to_checkout(
    user_id: str,
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """Return the checkout URL stored for *user_id* and clear it.

    Returns:
        ``{"url": <checkout url>}`` from the billing provider. The user's
        ``checkout_session`` is set to None afterwards.

    Raises:
        HTTPException(404): billing is disabled, or the user has no pending
            checkout session.
        HTTPException(501): the billing provider does not support this.

    NOTE(review): there is no authentication dependency; any caller who knows
    a user id can retrieve (and consume) that user's checkout URL.
    """
    require_cloud("Billing")
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()
        if not user or not user.checkout_session:
            # A missing user/session is a not-found condition, not a server error.
            raise HTTPException(status_code=404, detail=f"No checkout session found for user {user_id}")

        session_id = user.checkout_session
        try:
            checkout_url = billing.get_checkout_url(session_id)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc

        user.checkout_session = None
        session.commit()

    return {"url": checkout_url}
2436
+
2437
+
2438
+
2439
def track_payment_method(fingerprint: str):
    """Enforce the per-payment-method cap on referral credits.

    Counts the users who share *fingerprint* and hold a positive
    ``referral_credits`` balance. Note that this counts users, not credits.

    Raises:
        HTTPException(400): that count has reached 3.
    """
    credit_holder_cap = 3
    with compair.Session() as session:
        holders_with_credit = (
            session.query(models.User)
            .filter(
                models.User.payment_fingerprint == fingerprint,
                models.User.referral_credits > 0,
            )
            .count()
        )

    if holders_with_credit < credit_holder_cap:
        return
    raise HTTPException(
        status_code=400,
        detail="Referral credit limit reached for this payment method."
    )
2455
+
2456
def handle_successful_checkout(session):
    """Handle a completed checkout event from the billing provider.

    *session* is the provider's checkout-session payload. It is mapping-like
    and read with ``.get``; it is not a database session.

    Steps:
      1. Resolve the customer id, falling back to the customer email. Return
         early if neither is present, because a ``None`` id would match every
         user without a customer id.
      2. If a payment method is present, store it on the matching user and
         enforce the referral-credit cap via ``track_payment_method``.
      3. If the payment is ``"paid"`` and the user is on a team plan, email
         pending team invitations to existing users, or send a Compair invite
         to new ones, and persist the status changes.
    """
    checkout = session  # provider payload; DB sessions below use other names
    customer_id = checkout.get("customer")
    if customer_id is None:
        customer_id = (checkout.get("customer_details") or {}).get("email")
    if customer_id is None:
        return

    payment_method = checkout.get("payment_method")  # Might not exist if using trial
    if payment_method:
        with compair.Session() as db_session:
            user = db_session.query(models.User).filter(
                models.User.stripe_customer_id == customer_id
            ).first()
            if user:
                user.payment_fingerprint = payment_method
                track_payment_method(user.payment_fingerprint)
                db_session.commit()

    # This lookup used to run after ``session`` had been rebound to the
    # SQLAlchemy session above, so it indexed the DB session, not the payload.
    if checkout.get("payment_status") != "paid":
        return

    with compair.Session() as db_session:
        user = db_session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        if not user:
            return

        # ``plan`` may be unset (or absent on some models).
        if not (getattr(user, "plan", None) or "").startswith("team"):
            return

        team = db_session.query(models.Team).filter(
            models.Team.ownership.has(owner_id=user.user_id)
        ).first()
        if not team:
            return

        invitations = db_session.query(models.TeamInvitation).filter(
            models.TeamInvitation.inviter_id == user.user_id,
            models.TeamInvitation.status == "pending",
        ).all()
        for invitation in invitations:
            existing_user = db_session.query(models.User).filter(
                models.User.username == invitation.email
            ).first()

            if existing_user:
                emailer.connect()
                emailer.send(
                    subject="You're Invited to Join a Team on Compair",
                    sender=EMAIL_USER,
                    receivers=[invitation.email],
                    html=f"""
                <p>{user.name} has invited you to join their team on Compair!</p>
                <p>Click <a href="https://{WEB_URL}/accept-invitation?token={invitation.invitation_id}">here</a> to join.</p>
                """
                )
                invitation.status = "sent"
            else:
                send_invite(
                    invitee_emails=invitation.email,
                    current_user=user,
                )
        # Persist the "sent" status updates; they were previously never committed.
        db_session.commit()
2527
+
2528
def handle_successful_payment_intent(intent):
    """Record the payment method from a successful payment-intent event.

    When *intent* carries both ``customer`` and ``payment_method``, the
    method is stored as the ``payment_fingerprint`` of the user whose
    ``stripe_customer_id`` matches. The referral-credit cap is then enforced
    via ``track_payment_method``, which may raise HTTPException(400).

    Nothing happens without a customer id: querying ``== None`` would match
    every user lacking a customer id and attach the method to an arbitrary
    one.
    """
    customer_id = intent.get("customer")
    payment_method = intent.get("payment_method")
    if not payment_method or customer_id is None:
        return
    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        if user and hasattr(user, "payment_fingerprint"):
            user.payment_fingerprint = payment_method
            track_payment_method(user.payment_fingerprint)
            session.commit()
2547
+
2548
def _invoice_first_price_id(invoice):
    """Return the ID of the first priced line item on a Stripe invoice, or None."""
    lines = (invoice.get("lines") or {}).get("data") or []
    for line in lines:
        # Line items can carry ``"price": null``; probing ``"id" in None`` would
        # raise TypeError, so normalise to an empty mapping first.
        price = line.get("price") or {}
        if "id" in price:
            return price["id"]
    return None


def _plan_name_for_price_id(price_id):
    """Map a Stripe price ID back to its plan name via ``plan_to_id``, or None."""
    if not price_id:
        # Without this guard an unconfigured (None) price in ``plan_to_id``
        # would "match" an invoice that has no priced line item.
        return None
    return next((name for name, pid in plan_to_id.items() if pid == price_id), None)


def _award_referral_credit(session, user, billing):
    """
    Reward the user's referrer after the referred user has paid.

    If the referrer has pending referral credits, moves 10 of them to
    ``referral_credits`` (when that attribute exists), emails the referrer and
    asks ``billing`` for a coupon ($15, or $25 for plans whose name starts with
    "team"). Changes are left for the caller to commit.
    """
    if not (hasattr(user, "referred_by") and user.referred_by):
        return
    referrer = session.query(models.User).filter(
        models.User.user_id == user.referred_by
    ).first()
    if not referrer:
        return
    pending = getattr(referrer, "pending_referral_credits", None) or 0
    if pending <= 0:
        return

    # Convert pending credits to earned credits
    if hasattr(referrer, "referral_credits"):
        referrer.referral_credits += 10
        referrer.pending_referral_credits -= 10

    emailer.connect()
    emailer.send(
        subject="You've Earned a Free Month!",
        sender=EMAIL_USER,
        receivers=[referrer.username],
        html=REFERRAL_CREDIT_TEMPLATE.replace(
            "{{user_name}}", referrer.name
        ).replace(
            # getattr: models without ``referral_credits`` must not crash here.
            "{{referral_credits}}", str(getattr(referrer, "referral_credits", 0)),
        ),
    )

    # Team plans are detected by prefix (as with ``plan.startswith("team")`` in
    # the checkout handling) so variants such as "team_monthly" also qualify.
    is_team_plan = (getattr(referrer, "plan", None) or "").startswith("team")
    amount = 25 if is_team_plan else 15
    try:
        coupon_id = billing.create_coupon(amount)
        if referrer.stripe_customer_id:
            billing.apply_coupon(customer_id=referrer.stripe_customer_id, coupon_id=coupon_id)
    except NotImplementedError:
        # Billing providers without coupon support simply skip the reward.
        pass


def handle_successful_invoice_payment(
    invoice,
    billing: BillingProvider,
    analytics: Analytics | None = None,
):
    """
    Activate the paying user when a Stripe invoice is paid.

    For the user whose ``stripe_customer_id`` matches ``invoice["customer"]``:
    stores the invoice's payment method (if any), marks the user ``active``,
    stamps ``last_payment_date``/``status_change_date``, switches ``plan`` to
    the plan matching the invoice's price, rewards the user's referrer and
    reports ``subscription_created`` to ``analytics`` (best effort).

    Args:
        invoice: The ``data.object`` of an ``invoice.payment_succeeded`` event.
        billing: Provider used to create and apply referral coupons.
        analytics: Optional analytics sink; tracking failures are only printed.
    """
    customer_id = invoice.get("customer")
    if not customer_id:
        # ``stripe_customer_id == None`` would compile to IS NULL and activate
        # an arbitrary user, so ignore invoices without a customer.
        return

    plan = _plan_name_for_price_id(_invoice_first_price_id(invoice))
    payment_method = invoice.get("payment_method")

    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()
        if not user:
            return

        if payment_method:
            user.payment_fingerprint = payment_method
            track_payment_method(user.payment_fingerprint)

        now = datetime.now(timezone.utc)
        user.status = 'active'  # Mark the user as subscribed
        if hasattr(user, "last_payment_date"):
            user.last_payment_date = now
        user.status_change_date = now
        if plan and hasattr(user, "plan"):
            user.plan = plan

        _award_referral_credit(session, user, billing)

        if analytics:
            try:
                analytics.track("subscription_created", user.user_id)
            except Exception as exc:  # analytics must never block activation
                print(f"analytics track failed: {exc}")

        session.commit()
+
2642
def handle_subscription_cancellation(subscription):
    """
    Suspend the user whose Stripe subscription was deleted.

    Args:
        subscription: The ``data.object`` of a ``customer.subscription.deleted``
            event (dict-like; read via ``.get``).
    """
    customer_id = subscription.get("customer")
    if not customer_id:
        # Never filter on an empty customer: ``== None`` compiles to IS NULL and
        # would suspend an unrelated user.
        return
    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()

        if user:
            user.status = 'suspended'  # Remove access
            user.status_change_date = datetime.now(timezone.utc)
            session.commit()
+
2659
def handle_failed_payment(invoice):
    """
    Suspend the user whose invoice payment failed until payment is made.

    Args:
        invoice: The ``data.object`` of an ``invoice.payment_failed`` event
            (dict-like; read via ``.get``).
    """
    customer_id = invoice.get("customer")
    if not customer_id:
        # Never filter on an empty customer: ``== None`` compiles to IS NULL and
        # would suspend an unrelated user.
        return
    with compair.Session() as session:
        user = session.query(models.User).filter(
            models.User.stripe_customer_id == customer_id
        ).first()

        if user:
            user.status = 'suspended'  # Suspend access until payment is made
            user.status_change_date = datetime.now(timezone.utc)
            session.commit()
+
2676
@router.post("/webhook")
async def stripe_webhook(
    request: Request,
    billing: BillingProvider = Depends(get_billing),
    analytics: Analytics = Depends(get_analytics),
    settings: Settings = Depends(get_settings_dependency),
):
    """
    Verify an incoming billing webhook and dispatch it to the matching handler.

    Responses:
        404 when billing is disabled; 400 when ``billing.construct_event``
        rejects the payload/signature; 500 (generic detail) when a handler
        fails; otherwise ``{"status": "success"}``. Unrecognised event types
        are acknowledged without action.
    """
    if not settings.billing_enabled:
        raise HTTPException(status_code=404, detail="Not found")

    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    try:
        event = billing.construct_event(payload, sig_header)
    except Exception as exc:
        # Bad payloads/signatures are the sender's fault -> 400.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    event_type = event["type"]
    event_object = event["data"]["object"]
    log_event("stripe_webhook_start", event_type=event_type, data=event_object)

    handlers = {
        "payment_intent.succeeded": handle_successful_payment_intent,
        "invoice.payment_succeeded": lambda invoice: handle_successful_invoice_payment(
            invoice, billing, analytics
        ),
        "customer.subscription.deleted": handle_subscription_cancellation,
        "invoice.payment_failed": handle_failed_payment,
    }
    handler = handlers.get(event_type)
    if handler is not None:
        try:
            handler(event_object)
        except Exception as exc:
            # Processing failures are server-side: report 500 and keep the
            # internal error text out of the response body.
            print(f"stripe webhook {event_type} failed: {exc}")
            raise HTTPException(status_code=500, detail="Webhook processing failed") from exc

    log_event("stripe_webhook_complete", event_type=event_type, data=event_object)
    return {"status": "success"}
+
2717
+
2718
+
2719
def generate_referral_link(referral_code: str) -> str:
    """
    Build the sign-up URL that carries a user's referral code.

    The referral program is gated by ``require_cloud``; the link points at the
    ``/login`` page of ``WEB_URL`` with the code in the ``ref`` query parameter.
    """
    require_cloud("Referral program")
    return "http://{}/login?ref={}".format(WEB_URL, referral_code)
+
2727
@router.post("/send-invite")
def send_invite(
    invitee_emails: str = Form(...),
    group_id: str = Form(None),
    current_user: models.User = Depends(get_current_user),
):
    """
    Email invitations carrying the current user's referral link.

    ``invitee_emails`` is a comma-separated list. Surrounding whitespace is
    stripped and empty entries are ignored. When ``group_id`` is given, a
    pending ``GroupInvitation`` (expiring after 7 days) is stored for each
    address before its email is sent.
    """
    require_cloud("Referral program")
    # Called directly from Python (not via FastAPI), an omitted ``group_id`` is
    # the ``Form(None)`` marker object, which is truthy. Treat any non-string
    # value as "no group".
    if not isinstance(group_id, str):
        group_id = None
    # NOTE(review): current_user's membership in group_id is not verified here.
    recipients = [email for raw in invitee_emails.split(',') if (email := raw.strip())]

    referral_link = generate_referral_link(current_user.referral_code)
    with compair.Session() as session:
        emailer.connect()
        for invitee_email in recipients:
            if group_id:
                # Store a "pending" group invitation for this email
                invitation = models.GroupInvitation(
                    group_id=group_id,
                    inviter_id=current_user.user_id,
                    token=secrets.token_urlsafe(32).lower(),
                    email=invitee_email,
                    datetime_expiration=datetime.utcnow() + timedelta(days=7),
                    status="pending"
                )
                session.add(invitation)
                session.commit()
            emailer.send(
                subject="You're Invited to Compair!",
                sender=EMAIL_USER,
                receivers=[invitee_email],
                html=INDIVIDUAL_INVITATION_TEMPLATE.replace(
                    "{{inviter_name}}", current_user.name
                ).replace(
                    "{{referral_link}}", referral_link,
                )
            )

    return {"message": "Invitation sent successfully"}
+
2774
@router.post("/get-customer-portal")
def get_customer_portal(
    current_user: models.User = Depends(get_current_user),
    billing: BillingProvider = Depends(get_billing),
    settings: Settings = Depends(get_settings_dependency),
):
    """
    Return ``{"url": ...}`` for the user's billing-provider customer portal.

    Raises:
        HTTPException 501 if billing is disabled or the provider lacks portal
        support, 400 if the user has no Stripe customer, 500 on other errors.
    """
    if not settings.billing_enabled:
        raise HTTPException(status_code=501, detail="Billing is not available in this edition.")

    if WEB_URL:
        return_url = f"http://{WEB_URL}/home"
    else:
        return_url = "https://compair.sh/home"

    try:
        with compair.Session() as session:
            db_user = (
                session.query(models.User)
                .filter(models.User.user_id == current_user.user_id)
                .first()
            )
            customer_id = db_user.stripe_customer_id if db_user else None
            if not customer_id:
                raise HTTPException(status_code=400, detail="No billing profile found for this user.")

            try:
                url = billing.create_customer_portal(
                    customer_id=customer_id,
                    return_url=return_url,
                )
            except NotImplementedError as exc:
                raise HTTPException(status_code=501, detail=str(exc)) from exc
        return {"url": url}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
+
2807
+
2808
+
2809
+
2810
@router.post("/upload/profile")
async def upload_profile_image(
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """
    Upload the current user's profile picture to Cloudflare Images.

    Stores the returned image ID on the user and logs a ``cloudflare_upload``
    event with the uploaded byte count.

    Returns:
        ``{"image_id": <Cloudflare image id>}``.

    Raises:
        HTTPException 400 for a wrong ``upload_type`` or content type, 404 if
        the user row no longer exists, 500 if the Cloudflare upload fails.
    """
    if upload_type != 'profile':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # Read the body exactly once. A second ``file.read()`` returns b"" at EOF,
    # which previously made the logged file size always 0.
    file_bytes = await file.read()
    file_size = len(file_bytes)

    # Upload to Cloudflare Images
    async with httpx.AsyncClient() as client:
        response = await client.post(
            CLOUDFLARE_IMAGES_UPLOAD_URL,
            headers={"Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"},
            data={"requireSignedURLs": "false"},
            files={"file": (file.filename, file_bytes, file.content_type)},
            timeout=30
        )
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    result = response.json()
    if not result.get("success"):
        raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

    image_id = result["result"]["id"]

    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == current_user.user_id).first()
        if user is None:
            raise HTTPException(status_code=404, detail="User not found")
        user.profile_image = image_id
        session.commit()

    # Log the upload event
    log_event(
        "cloudflare_upload",
        upload_type="profile_image",
        user_id=current_user.user_id,
        file_size=file_size,
        file_key=image_id,
        content_type=file.content_type
    )

    return {"image_id": image_id}
+
2872
+
2873
@router.get("/get_profile_image")
def get_profile_image(
    variant: str = Query("public", enum=["public", "avatar", "preview"]),
    current_user: models.User = Depends(get_current_user)
) -> Mapping[str, str]:
    """
    Return the Cloudflare Images delivery URL of the user's profile picture.

    Args:
        variant: Cloudflare Images variant name appended to the URL.

    Raises:
        HTTPException 400 when the user has no profile image.
    """
    # Only attributes of ``current_user`` are read, so no DB session is needed.
    image_id = current_user.profile_image
    if not image_id:
        raise HTTPException(status_code=400, detail="No profile image found")
    return {"url": f"{CLOUDFLARE_IMAGES_BASE_URL}/{image_id}/{variant}"}
+
2885
+
2886
@router.post("/upload/file")
async def upload_file(
    document_id: str = Form(...),
    upload_type: str = Form(...),
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
) -> Mapping[str, str]:
    """
    Store a PDF for one of the current user's documents and link it.

    The object key is ``file/<user_id>/<sha256(filename)>/<document_id>``. On
    success the document's ``file_key`` is updated and a ``cloudflare_upload``
    event is logged.

    Returns:
        ``{"url": <value returned by storage.put_file>}``.

    Raises:
        HTTPException 400 for a wrong ``upload_type``/content type, 404 if the
        document does not exist, 403 if the user is not its author, 501 if the
        storage backend cannot store files, 500 on other storage errors.
    """
    if upload_type != 'file':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["application/pdf"]  # "document/docx" not supported yet
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    file_name_hash = hashlib.sha256(f"{file.filename}".encode()).hexdigest()
    file_key = f"{upload_type}/{current_user.user_id}/{file_name_hash}/{document_id}"
    content_type = file.content_type or "application/octet-stream"

    with compair.Session() as session:
        # Validate the target document before touching storage. This avoids
        # orphaned objects for bad IDs and stops users from re-pointing someone
        # else's document at their upload.
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        if document.author_id != current_user.user_id:
            raise HTTPException(status_code=403, detail="Not authorized to upload a file for this document.")

        # Read file bytes to get size and reset file pointer before upload
        file_bytes = await file.read()
        file_size = len(file_bytes)
        await file.seek(0)

        try:
            upload_url = storage.put_file(file_key, file.file, content_type)
        except NotImplementedError as exc:
            raise HTTPException(status_code=501, detail=str(exc)) from exc
        except Exception as exc:  # pragma: no cover - surfaced to client
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        document.file_key = file_key
        session.commit()

    # Log the upload event
    log_event(
        "cloudflare_upload",
        upload_type="document_file",
        user_id=current_user.user_id,
        document_id=document_id,
        file_size=file_size,
        file_key=file_key,
        content_type=content_type
    )

    return {"url": upload_url}
+
2941
+
2942
@router.post("/upload/group")
async def upload_group_image(
    group_id: str,
    upload_type: str,
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
) -> Mapping[str, str]:
    """
    Upload a group's image to Cloudflare Images and store its ID on the group.

    The caller must be authenticated and a member of the group.

    Returns:
        ``{"image_id": <Cloudflare image id>}``.

    Raises:
        HTTPException 400 for a wrong ``upload_type`` or content type, 404 for
        an unknown group, 403 if the user is not a member, 500 if the
        Cloudflare upload fails.
    """
    if upload_type != 'group':
        raise HTTPException(status_code=400, detail="Invalid upload type")

    allowed_types = ["image/jpg", "image/jpeg", "image/png", "image/webp"]
    if file.content_type not in allowed_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    with compair.Session() as session:
        group = session.query(models.Group).filter(models.Group.group_id == group_id).first()
        if group is None:
            raise HTTPException(status_code=404, detail="Group not found")
        # Previously unauthenticated: anyone could replace any group's image.
        if group_id not in {g.group_id for g in current_user.groups}:
            raise HTTPException(status_code=403, detail="Not authorized to modify this group.")

        # Read once; a second read returns b"" and logged a size of 0.
        file_bytes = await file.read()
        file_size = len(file_bytes)

        # Upload to Cloudflare Images
        async with httpx.AsyncClient() as client:
            response = await client.post(
                CLOUDFLARE_IMAGES_UPLOAD_URL,
                headers={"Authorization": f"Bearer {CLOUDFLARE_IMAGES_TOKEN}"},
                data={"requireSignedURLs": "false"},
                files={"file": (file.filename, file_bytes, file.content_type)},
                timeout=30
            )
        if response.status_code != 200:
            raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

        result = response.json()
        if not result.get("success"):
            raise HTTPException(status_code=500, detail="Cloudflare Images upload failed")

        image_id = result["result"]["id"]
        group.group_image = image_id
        session.commit()

    # Log the upload event
    log_event(
        "cloudflare_upload",
        upload_type="group_image",
        group_id=group_id,
        user_id=current_user.user_id,
        file_size=file_size,
        file_key=image_id,
        content_type=file.content_type
    )

    # The signature promises a mapping; previously nothing was returned.
    return {"image_id": image_id}
+
3002
@router.post("/send-feature-announcement")
def send_feature_announcement(admin_key: str = Header(None)):
    """
    Manually trigger a feature announcement email to churned users.

    Requires the Cloud edition and an ``admin-key`` header equal to
    ``ADMIN_API_KEY``. The campaign runs as a Celery task via ``.delay()``.

    Raises:
        HTTPException 501 outside Compair Cloud, 403 for a missing or wrong key.
    """
    if not IS_CLOUD or send_feature_announcement_task is None:
        raise HTTPException(status_code=501, detail="Feature announcements require the Compair Cloud edition.")
    # Fail closed. With ADMIN_API_KEY unset, the old ``admin_key != ADMIN_API_KEY``
    # check accepted requests that omitted the header (None == None).
    # compare_digest also avoids leaking the key through timing.
    if (
        not ADMIN_API_KEY
        or not admin_key
        or not secrets.compare_digest(str(admin_key).encode(), str(ADMIN_API_KEY).encode())
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")

    send_feature_announcement_task.delay()  # Run email sending as an async Celery task

    return {"message": "Feature announcement email campaign triggered successfully."}
+
3017
+
3018
@router.post("/retrial/{user_id}")
def grant_retrial(user_id: str):
    """
    Reactivate a suspended user's trial for 14 days; allowed once per user.

    Raises:
        HTTPException 501 outside Compair Cloud, 400 if the user does not
        exist or is not suspended, 403 once the single re-trial has been used.
    """
    # NOTE(review): no authentication dependency is attached, so any caller who
    # knows a user_id can trigger this. Confirm this is intended.
    if not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Re-trials require the Compair Cloud edition.")
    # Use the same session factory as every other endpoint in this module.
    with compair.Session() as session:
        user = session.query(models.User).filter(models.User.user_id == user_id).first()

        if not user or user.status != "suspended":
            raise HTTPException(status_code=400, detail="User not eligible for a re-trial.")

        # A NULL counter means "never re-trialled" (``None >= 1`` would raise).
        retrials = user.retrial_count or 0
        if retrials >= 1:
            raise HTTPException(status_code=403, detail="Re-trial limit reached.")

        # Reactivate trial
        now = datetime.now(timezone.utc)
        user.status = "trial"
        user.status_change_date = now
        user.trial_expiration_date = now + timedelta(days=14)
        user.retrial_count = retrials + 1
        user.last_retrial_date = now
        session.commit()

    return {"message": "Your re-trial has been activated!"}
+
3044
@router.post("/documents/{document_id}/notes")
def create_note(
    document_id: str,
    content: str = Form(...),
    group_id: str = Form(None),
    current_user: models.User = Depends(get_current_user)
):
    """
    Create a Note on a Document and chunk/embed its content.

    The user must share at least one group with the document. A supplied
    ``group_id`` must be one of those shared groups; otherwise the first
    shared group is used. Each text chunk is embedded and stored as a
    ``Chunk`` with ``chunk_type="note"``, and a ``create`` activity is logged.

    Raises:
        HTTPException 404 if the document does not exist, 403 if the user
        shares no group with it or ``group_id`` is not a shared group.
    """
    with compair.Session() as session:
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        # Check group membership (document order preserved for the default).
        user_group_ids = {g.group_id for g in current_user.groups}
        shared_group_ids = [g.group_id for g in document.groups if g.group_id in user_group_ids]
        if not shared_group_ids:
            raise HTTPException(status_code=403, detail="User not in document's group")
        if not group_id:
            group_id = shared_group_ids[0]
        elif group_id not in shared_group_ids:
            # Previously any group_id was accepted, letting notes land in
            # groups the user does not belong to.
            raise HTTPException(status_code=403, detail="User not in document's group")

        note = models.Note(
            document_id=document_id,
            author_id=current_user.user_id,
            group_id=group_id,
            content=content,
            datetime_created=datetime.now(timezone.utc)
        )
        session.add(note)
        session.commit()
        session.refresh(note)

        # Chunk and embed Note content
        embedder = Embedder()
        for text in chunk_text(content):
            # NOTE(review): built-in hash() of str is randomized per process
            # (PYTHONHASHSEED), so this value is not stable across restarts.
            note_chunk = models.Chunk(
                hash=str(hash(text)),
                document_id=document_id,
                note_id=note.note_id,
                chunk_type="note",
                content=text,
            )
            note_chunk.embedding = create_embedding(embedder, text, user=current_user)
            session.add(note_chunk)
        session.commit()

        log_activity(
            session=session,
            user_id=current_user.user_id,
            group_id=note.group_id,
            action="create",
            object_id=note.note_id,
            object_name=f"Note on {document.title}",
            object_type="note"
        )

        return schema.Note(
            note_id=note.note_id,
            document_id=note.document_id,
            author_id=note.author_id,
            group_id=note.group_id,
            content=note.content,
            datetime_created=note.datetime_created,
            author=schema.User(
                user_id=current_user.user_id,
                username=current_user.username,
                name=current_user.name,
                datetime_registered=current_user.datetime_registered,
                status=current_user.status,
                profile_image=current_user.profile_image,
                role=current_user.role,
            )
        )
+
3116
@router.get("/documents/{document_id}/notes")
def list_notes(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
) -> list[schema.Note]:
    """
    List all Notes for a Document the user may see.

    Access is granted to the document's author and to users sharing a group
    with it. Authors are loaded in one query, not one query per note.

    Raises:
        HTTPException 404 if the document does not exist, 403 otherwise
        unauthorized.
    """
    with compair.Session() as session:
        document = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not document:
            raise HTTPException(status_code=404, detail="Document not found")
        if document.author_id != current_user.user_id:
            doc_group_ids = {g.group_id for g in document.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="User not in document's group")

        notes = session.query(models.Note).filter(models.Note.document_id == document_id).all()
        author_ids = list({note.author_id for note in notes})
        authors = {}
        if author_ids:
            authors = {
                user.user_id: user
                for user in session.query(models.User).filter(models.User.user_id.in_(author_ids)).all()
            }

        def to_schema(note):
            """Convert a Note row (plus its cached author) to ``schema.Note``."""
            author = authors.get(note.author_id)
            return schema.Note(
                note_id=note.note_id,
                document_id=note.document_id,
                author_id=note.author_id,
                group_id=note.group_id,
                content=note.content,
                datetime_created=note.datetime_created,
                author=schema.User(
                    user_id=author.user_id,
                    username=author.username,
                    name=author.name,
                    datetime_registered=author.datetime_registered,
                    status=author.status,
                    profile_image=author.profile_image,
                    role=author.role,
                ) if author else None
            )

        return [to_schema(note) for note in notes]
+
3146
@router.get("/notes/{note_id}")
def get_note(
    note_id: str,
    current_user: models.User = Depends(get_current_user)
) -> schema.Note:
    """
    Get a single Note by ID.

    Visible to the note's author, the document's author, and users sharing a
    group with the note's document.

    Raises:
        HTTPException 404 if the note does not exist, 403 if not visible.
    """
    with compair.Session() as session:
        note = session.query(models.Note).filter(models.Note.note_id == note_id).first()
        if not note:
            raise HTTPException(status_code=404, detail="Note not found")
        if note.author_id != current_user.user_id:
            document = session.query(models.Document).filter(
                models.Document.document_id == note.document_id
            ).first()
            if document is None or document.author_id != current_user.user_id:
                doc_group_ids = {g.group_id for g in document.groups} if document else set()
                user_group_ids = {g.group_id for g in current_user.groups}
                if not doc_group_ids & user_group_ids:
                    raise HTTPException(status_code=403, detail="User not in document's group")

        author = session.query(models.User).filter(models.User.user_id == note.author_id).first()
        return schema.Note(
            note_id=note.note_id,
            document_id=note.document_id,
            author_id=note.author_id,
            group_id=note.group_id,
            content=note.content,
            datetime_created=note.datetime_created,
            author=schema.User(
                user_id=author.user_id,
                username=author.username,
                name=author.name,
                datetime_registered=author.datetime_registered,
                status=author.status,
                profile_image=author.profile_image,
                role=author.role,
            ) if author else None
        )
+
3175
@router.get("/documents/{document_id}/file-url")
def get_document_file_url(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """
    Return the storage URL of a Document's file.

    Uses the same rule as the download endpoints: the author always, other
    users only for published documents in a shared group.

    Raises:
        HTTPException 404 if the document or its file is missing, 403 if the
        user may not access it.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")
        if current_user.user_id != doc.author_id:
            doc_group_ids = {g.group_id for g in doc.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc.is_published or not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to access this file.")
        return {"file_url": storage.build_url(doc.file_key)}
+
3189
@router.get("/documents/{document_id}/image-url")
def get_document_image_url(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """
    Return the storage URL of a Document's preview image.

    Uses the same rule as the download endpoints: the author always, other
    users only for published documents in a shared group.

    Raises:
        HTTPException 404 if the document or its image is missing, 403 if the
        user may not access it.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.image_key:
            raise HTTPException(status_code=404, detail="Image not found for this document.")
        if current_user.user_id != doc.author_id:
            doc_group_ids = {g.group_id for g in doc.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc.is_published or not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to access this image.")
        return {"image_url": storage.build_url(doc.image_key)}
+
3203
+
3204
def sanitize_filename(filename: str) -> str:
    """
    Make a title safe to embed in a ``Content-Disposition`` header.

    Replaces each run of characters outside printable ASCII with a single
    ``_``. That covers non-ASCII letters and also control characters such as
    CR/LF, which could otherwise split the header. Every double quote and
    backslash is also replaced, since they would terminate or escape a quoted
    ``filename=`` value.

    >>> sanitize_filename("Résumé.pdf")
    'R_sum_.pdf'
    """
    return re.sub(r'[^\x20-\x7E]+|["\\]', '_', filename)
+
3208
+
3209
@router.post("/documents/{document_id}/generate-download-token")
def generate_download_token(
    document_id: str,
    current_user: models.User = Depends(get_current_user)
):
    """
    Issue a short-lived download token for a Document's file.

    Authorization: the author always; other users only for published
    documents that share a group with them. The token maps to ``document_id``
    in ``redis_client`` for 300 seconds.

    Returns:
        ``{"download_url": "/documents/download/<token>"}``.

    Raises:
        HTTPException 404 if the document/file is missing, 403 if unauthorized.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        if current_user.user_id != doc.author_id:
            if not doc.is_published:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")
            doc_group_ids = {g.group_id for g in doc.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")

    token = secrets.token_urlsafe(32)
    # The token is a bearer credential, so it must never be printed/logged.
    redis_client.setex(f"download_token:{token}", 300, document_id)
    return {"download_url": f"/documents/download/{token}"}
+
3238
+
3239
@router.get("/documents/download/{token}")
def download_document_with_token(
    token: str,
    storage: StorageProvider = Depends(get_storage),
):
    """
    Stream a Document's file using a token from ``generate-download-token``.

    The token's key is deleted on first use.

    Raises:
        HTTPException 403 for an unknown/expired token, 404 if the document,
        its file key or the stored object is missing.
    """
    key = f"download_token:{token}"
    # Single lookup: the old double ``get`` could see the key expire between
    # calls and crash on ``None.decode``. The token itself is never logged.
    raw_document_id = redis_client.get(key)
    if not raw_document_id:
        raise HTTPException(status_code=403, detail="Invalid or expired token")
    # NOTE(review): get + delete is not atomic, so two concurrent requests can
    # redeem the same token. Consider GETDEL or a transaction.
    redis_client.delete(key)
    if isinstance(raw_document_id, bytes):
        document_id = raw_document_id.decode('utf-8')
    else:
        document_id = raw_document_id

    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        # Quote the filename so spaces/semicolons in titles survive, and
        # neutralise characters that would break the quoted string or header.
        safe_title = re.sub(r'[\x00-\x1f\x7f"\\]', '_', sanitize_filename(doc.title or 'file')) or 'file'
        try:
            file_obj, content_type = storage.get_file(doc.file_key)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc

        return StreamingResponse(
            file_obj,
            media_type=content_type,
            headers={
                "Content-Disposition": f'attachment; filename="{safe_title}"'
            }
        )
+
3270
+
3271
@router.get("/documents/{document_id}/download")
def download_document_file(
    document_id: str,
    current_user: models.User = Depends(get_current_user),
    storage: StorageProvider = Depends(get_storage),
):
    """
    Stream a Document's file to an authorized user.

    Authorization: the author always; other users only for published
    documents that share a group with them. This matches
    ``generate-download-token``.

    Raises:
        HTTPException 404 if the document, its file key or the stored object is
        missing, 403 if the user is not authorized.
    """
    with compair.Session() as session:
        doc = session.query(models.Document).filter(models.Document.document_id == document_id).first()
        # Check existence before touching ``doc.title`` (previously None -> 500).
        if not doc or not doc.file_key:
            raise HTTPException(status_code=404, detail="File not found for this document.")

        if current_user.user_id != doc.author_id:
            # Previously unpublished documents had no branch here and were
            # downloadable by any authenticated user.
            if not doc.is_published:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")
            doc_group_ids = {g.group_id for g in doc.groups}
            user_group_ids = {g.group_id for g in current_user.groups}
            if not doc_group_ids & user_group_ids:
                raise HTTPException(status_code=403, detail="Not authorized to download this file.")

        # Quote the filename so spaces/semicolons in titles survive, and
        # neutralise characters that would break the quoted string or header.
        safe_title = re.sub(r'[\x00-\x1f\x7f"\\]', '_', sanitize_filename(doc.title or 'file')) or 'file'

        # Fetch file from storage and stream to user
        try:
            file_obj, content_type = storage.get_file(doc.file_key)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=404, detail="Stored file could not be located.") from exc

        return StreamingResponse(
            file_obj,
            media_type=content_type,
            headers={
                "Content-Disposition": f'attachment; filename="{safe_title}"'
            }
        )
+
3309
+
3310
@router.post("/help-request")
def submit_help_request(
    content: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """
    Record a help request and queue the notification email.

    Raises:
        HTTPException 501 outside Compair Cloud. This is checked before
        anything is stored, so a rejected request leaves no row behind.
    """
    if not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Help request emails require the Compair Cloud edition.")
    with compair.Session() as session:
        help_request = models.HelpRequest(
            user_id=current_user.user_id,
            content=content,
            datetime_created=datetime.now(timezone.utc)
        )
        session.add(help_request)
        session.commit()
        send_help_request_email.delay(help_request.request_id)
    return {"message": "Your request has been submitted. Our team will get back to you soon."}
+
3328
+
3329
@router.post("/deactivate-account")
def submit_deactivate_request(
    notice: str = Form(...),
    current_user: models.User = Depends(get_current_user)
):
    """
    Record an account-deactivation request and queue the notification email.

    Raises:
        HTTPException 501 outside Compair Cloud. This is checked before
        anything is stored, so a rejected request leaves no row behind.
    """
    if not IS_CLOUD:
        raise HTTPException(status_code=501, detail="Deactivate request emails require the Compair Cloud edition.")
    with compair.Session() as session:
        deactivate_request = models.DeactivateRequest(
            user_id=current_user.user_id,
            notice=notice,
            datetime_created=datetime.now(timezone.utc)
        )
        session.add(deactivate_request)
        session.commit()
        send_deactivate_request_email.delay(deactivate_request.request_id)
    return {"message": f"We’ve received your request and will delete your account and data shortly. If you change your mind, reach out within 24 hours at {EMAIL_USER}."}
+
3347
+
3348
def create_fastapi_app():
    """
    Build a standalone FastAPI application that mounts this module's ``router``.

    Kept so the module can still be run directly (backwards compatibility).
    FastAPI is imported lazily inside the factory.
    """
    from fastapi import FastAPI

    application = FastAPI()
    application.include_router(router)
    return application


# Application instance built once at import time from the factory above.
app = create_fastapi_app()