Skip to content

Commit 6d409ff

Browse files
committed
feat(chat): add conversation sharing with SharePersistence and UI flow
Introduce a SharePersistence layer, share/unshare API endpoints, and a SQL-backed implementation, and wire the React UI (ShareButton, ChatHistory, HistoryStore, ConversationGuard) through a new shares API client so conversations can be generated, fetched, and revoked via shareable links. Adds unit tests for share persistence and SQL conversation handling.
1 parent 228a4d6 commit 6d409ff

47 files changed

Lines changed: 1647 additions & 443 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/ragbits-chat/src/ragbits/chat/api.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414
from fastapi import FastAPI, HTTPException, Request, UploadFile, status
1515
from fastapi.exceptions import RequestValidationError
1616
from fastapi.middleware.cors import CORSMiddleware
17-
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse, StreamingResponse
17+
from fastapi.responses import (
18+
HTMLResponse,
19+
JSONResponse,
20+
PlainTextResponse,
21+
RedirectResponse,
22+
Response,
23+
StreamingResponse,
24+
)
1825
from fastapi.staticfiles import StaticFiles
1926
from pydantic import BaseModel
2027

@@ -32,13 +39,19 @@
3239
ChatResponseUnion,
3340
ChunkedContent,
3441
ConfigResponse,
42+
ConversationDetail,
43+
ConversationMeta,
44+
ConversationShareResponse,
3545
FeedbackConfig,
3646
FeedbackItem,
3747
FeedbackRequest,
3848
Image,
3949
ImageResponse,
4050
OAuth2ProviderConfig,
51+
ShareConversationRequest,
4152
)
53+
from ragbits.chat.persistence.share import SharePersistence
54+
from ragbits.chat.persistence.sql import SQLHistoryPersistence
4255
from ragbits.core.audit.metrics import record_metric
4356
from ragbits.core.audit.metrics.base import MetricType
4457
from ragbits.core.audit.traces import trace
@@ -61,6 +74,7 @@ def __init__(
6174
debug_mode: bool = False,
6275
auth_backend: AuthenticationBackend | type[AuthenticationBackend] | str | None = None,
6376
theme_path: str | None = None,
77+
share_persistence: SharePersistence | None = None,
6478
) -> None:
6579
"""
6680
Initialize the RagbitsAPI.
@@ -73,13 +87,15 @@ def __init__(
7387
debug_mode: Flag enabling debug tools in the default UI
7488
auth_backend: Authentication backend for user authentication. If None, no authentication required.
7589
theme_path: Path to a JSON file containing HeroUI theme configuration from heroui.com/themes
90+
share_persistence: Optional share persistence for conversation sharing. Requires auth_backend.
7691
"""
7792
self.chat_interface: ChatInterface = self._load_chat_interface(chat_interface)
7893
self.dist_dir = Path(ui_build_dir) if ui_build_dir else Path(__file__).parent / "ui-build"
7994
self.cors_origins = cors_origins or []
8095
self.debug_mode = debug_mode
8196
self.auth_backend = self._load_auth_backend(auth_backend)
8297
self.theme_path = Path(theme_path) if theme_path else None
98+
self.share_persistence = share_persistence
8399

84100
self.frontend_base_url = BASE_URL
85101

@@ -269,6 +285,7 @@ async def config() -> JSONResponse:
269285
oauth2_providers=oauth2_providers,
270286
),
271287
supports_upload=self.chat_interface.upload_handler is not None,
288+
sharing=self.share_persistence is not None,
272289
)
273290

274291
return JSONResponse(content=config_response.model_dump())
@@ -301,12 +318,166 @@ async def theme() -> PlainTextResponse:
301318
logger.error(f"Error serving theme: {e}")
302319
raise HTTPException(status_code=500, detail="Error loading theme") from e
303320

321+
if self.share_persistence and self.auth_backend:
322+
self._setup_share_routes()
323+
304324
@self.app.get("/{full_path:path}", response_class=HTMLResponse)
305325
async def root() -> HTMLResponse:
306326
index_file = self.dist_dir / "index.html"
307327
with open(str(index_file)) as file:
308328
return HTMLResponse(content=file.read())
309329

330+
def _setup_share_routes(self) -> None: # noqa: PLR0915
331+
"""Register routes for conversation sharing. Requires auth_backend and share_persistence."""
332+
share = self.share_persistence
333+
assert share is not None # noqa: S101
334+
335+
history_persistence = self.chat_interface.history_persistence
336+
if not isinstance(history_persistence, SQLHistoryPersistence):
337+
logger.warning("Share routes require SQLHistoryPersistence; sharing disabled.")
338+
return
339+
340+
history: SQLHistoryPersistence = history_persistence
341+
342+
async def _require_user(request: Request) -> User:
343+
user = await self.require_authenticated_user(request)
344+
if not user:
345+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
346+
return user
347+
348+
async def _require_owner(conversation_id: str, user: User) -> None:
349+
owner = await history.get_conversation_owner(conversation_id)
350+
if owner != user.user_id:
351+
raise HTTPException(status_code=404, detail="Conversation not found.")
352+
353+
def _recipient_identifiers(user: User) -> list[str]:
354+
"""Return all identifiers a user can be addressed by when sharing.
355+
356+
The share UI lowercases input, so recipients are stored lowercased.
357+
We mirror that here to match whether the owner shared with a user_id,
358+
a username, or an email.
359+
"""
360+
ids: list[str] = []
361+
seen: set[str] = set()
362+
for raw in (user.user_id, user.username, user.email):
363+
if not raw:
364+
continue
365+
normalized = raw.strip().lower()
366+
if normalized and normalized not in seen:
367+
seen.add(normalized)
368+
ids.append(normalized)
369+
return ids
370+
371+
@self.app.get("/api/conversations")
372+
async def list_conversations(request: Request) -> list[ConversationMeta]:
373+
user = await _require_user(request)
374+
owned = await history.list_conversations(user.user_id)
375+
owned_ids = {c["id"] for c in owned}
376+
shared_rows = await share.list_shared_with_me(_recipient_identifiers(user))
377+
shared_rows = [r for r in shared_rows if r["conversation_id"] not in owned_ids]
378+
379+
all_ids = [c["id"] for c in owned] + [r["conversation_id"] for r in shared_rows]
380+
summaries = await history.get_conversation_summaries(all_ids)
381+
382+
owned_metas = [
383+
ConversationMeta(
384+
conversation_id=c["id"],
385+
created_at=str(c["created_at"]) if c["created_at"] else "",
386+
summary=summaries.get(c["id"]),
387+
)
388+
for c in owned
389+
]
390+
shared_metas = [
391+
ConversationMeta(
392+
conversation_id=r["conversation_id"],
393+
created_at=str(r["shared_at"]) if r["shared_at"] else "",
394+
summary=summaries.get(r["conversation_id"]),
395+
is_shared=True,
396+
shared_by=r["owner_id"],
397+
)
398+
for r in shared_rows
399+
]
400+
return owned_metas + shared_metas
401+
402+
@self.app.get("/api/conversations/{conversation_id}")
403+
async def get_conversation(request: Request, conversation_id: str) -> ConversationDetail:
404+
user = await _require_user(request)
405+
owner = await history.get_conversation_owner(conversation_id)
406+
is_shared = False
407+
shared_by: str | None = None
408+
if owner == user.user_id:
409+
pass
410+
elif await share.can_access(conversation_id, _recipient_identifiers(user)):
411+
is_shared = True
412+
shared_by = owner
413+
else:
414+
raise HTTPException(status_code=404, detail="Conversation not found.")
415+
interactions = await history.get_conversation_interactions(conversation_id)
416+
return ConversationDetail(
417+
conversation_id=conversation_id,
418+
messages=interactions,
419+
is_shared=is_shared,
420+
shared_by=shared_by,
421+
)
422+
423+
@self.app.delete("/api/conversations/{conversation_id}", status_code=204)
424+
async def delete_conversation(request: Request, conversation_id: str) -> Response:
425+
user = await _require_user(request)
426+
await _require_owner(conversation_id, user)
427+
await history.delete_conversation(conversation_id)
428+
return Response(status_code=204)
429+
430+
@self.app.get("/api/conversations/{conversation_id}/shares")
431+
async def list_shares(request: Request, conversation_id: str) -> list[ConversationShareResponse]:
432+
user = await _require_user(request)
433+
await _require_owner(conversation_id, user)
434+
shares = await share.get_shares(conversation_id, user.user_id)
435+
return [
436+
ConversationShareResponse(
437+
recipient=s["recipient"],
438+
shared_at=str(s["shared_at"]) if s["shared_at"] else "",
439+
)
440+
for s in shares
441+
]
442+
443+
@self.app.put("/api/conversations/{conversation_id}/shares")
444+
async def update_shares(
445+
request: Request, conversation_id: str, body: ShareConversationRequest
446+
) -> list[ConversationShareResponse]:
447+
user = await _require_user(request)
448+
await _require_owner(conversation_id, user)
449+
new_recipients = set(body.recipients)
450+
existing = await share.get_shares(conversation_id, user.user_id)
451+
existing_recipients = {s["recipient"] for s in existing}
452+
to_add = list(new_recipients - existing_recipients)
453+
to_remove = list(existing_recipients - new_recipients)
454+
if to_remove:
455+
await share.remove_shares(conversation_id, user.user_id, to_remove)
456+
if to_add:
457+
await share.set_shares(conversation_id, user.user_id, to_add)
458+
updated = await share.get_shares(conversation_id, user.user_id)
459+
return [
460+
ConversationShareResponse(
461+
recipient=s["recipient"],
462+
shared_at=str(s["shared_at"]) if s["shared_at"] else "",
463+
)
464+
for s in updated
465+
]
466+
467+
@self.app.delete("/api/conversations/{conversation_id}/shares/{recipient:path}", status_code=204)
468+
async def revoke_share(request: Request, conversation_id: str, recipient: str) -> Response:
469+
user = await _require_user(request)
470+
await _require_owner(conversation_id, user)
471+
await share.remove_shares(conversation_id, user.user_id, [recipient])
472+
return Response(status_code=204)
473+
474+
@self.app.delete("/api/shared/{conversation_id}", status_code=204)
475+
async def dismiss_share(request: Request, conversation_id: str) -> Response:
476+
user = await _require_user(request)
477+
if not await share.hide_share(conversation_id, _recipient_identifiers(user)):
478+
raise HTTPException(status_code=404, detail="Shared conversation not found.")
479+
return Response(status_code=204)
480+
310481
@staticmethod
311482
def _prepare_chat_context(
312483
request: ChatMessageRequest,

packages/ragbits-chat/src/ragbits/chat/interface/types.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,3 +907,41 @@ class ConfigResponse(BaseModel):
907907
conversation_history: bool = Field(default=False, description="Flag to enable conversation history")
908908
show_usage: bool = Field(default=False, description="Flag to enable usage statistics")
909909
authentication: AuthenticationConfig = Field(..., description="Authentication configuration")
910+
sharing: bool = Field(default=False, description="Flag to enable conversation sharing")
911+
912+
913+
# ---------------------------------------------------------------------------
914+
# Conversation sharing models
915+
# ---------------------------------------------------------------------------
916+
917+
918+
class ShareConversationRequest(BaseModel):
919+
"""Validated input for sharing a conversation."""
920+
921+
recipients: list[str] = Field(min_length=1, description="List of recipient identifiers (user IDs or emails)")
922+
923+
924+
class ConversationShareResponse(BaseModel):
925+
"""API response model for a single share recipient."""
926+
927+
recipient: str
928+
shared_at: str
929+
930+
931+
class ConversationMeta(BaseModel):
932+
"""Conversation metadata returned in list endpoints."""
933+
934+
conversation_id: str
935+
created_at: str
936+
summary: str | None = None
937+
is_shared: bool = False
938+
shared_by: str | None = None
939+
940+
941+
class ConversationDetail(BaseModel):
942+
"""Full conversation detail including messages."""
943+
944+
conversation_id: str
945+
messages: list[dict[str, Any]]
946+
is_shared: bool = False
947+
shared_by: str | None = None
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from ragbits.chat.persistence.base import HistoryPersistenceStrategy
2+
from ragbits.chat.persistence.share import SharePersistence
23

3-
__all__ = ["HistoryPersistenceStrategy"]
4+
__all__ = ["HistoryPersistenceStrategy", "SharePersistence"]

0 commit comments

Comments
 (0)