diff --git a/frontend/src/components/ProjectLockedModal.tsx b/frontend/src/components/ProjectLockedModal.tsx new file mode 100644 index 0000000..41fe99c --- /dev/null +++ b/frontend/src/components/ProjectLockedModal.tsx @@ -0,0 +1,124 @@ +import { useState } from "react"; +import { AlertTriangle, ArrowLeft, UserCheck } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; + +interface ProjectLockedModalProps { + lockedByEmail: string; + lockedAt?: string; + onTakeOver: () => void; + onGoBack: () => void; + className?: string; +} + +export function ProjectLockedModal({ + lockedByEmail, + lockedAt, + onTakeOver, + onGoBack, + className, +}: ProjectLockedModalProps) { + const [confirmTakeover, setConfirmTakeover] = useState(false); + + const formatLockedTime = (isoString?: string) => { + if (!isoString) return null; + try { + const date = new Date(isoString); + return date.toLocaleTimeString(undefined, { + hour: "numeric", + minute: "2-digit", + }); + } catch { + return null; + } + }; + + const lockedTime = formatLockedTime(lockedAt); + + return ( +
+
+ {/* Header */} +
+
+
+ +
+
+

+ Project in use +

+

+ Someone else is currently editing +

+
+
+
+ + {/* Content */} +
+

+ This project is being edited by{" "} + {lockedByEmail} + {lockedTime && ( + since {lockedTime} + )} + . +

+ + {!confirmTakeover ? ( +
+ + +
+ ) : ( +
+
+ +

+ Taking over will disconnect{" "} + {lockedByEmail} from the + project. They may lose unsaved changes. +

+
+
+ + +
+
+ )} +
+
+
+ ); +} diff --git a/frontend/src/components/ProjectMembers.tsx b/frontend/src/components/ProjectMembers.tsx new file mode 100644 index 0000000..0fb3b32 --- /dev/null +++ b/frontend/src/components/ProjectMembers.tsx @@ -0,0 +1,204 @@ +import { useState, useCallback, useEffect, type FormEvent } from "react"; +import { Loader2, UserPlus, X, Users } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { AGENT_CONFIG } from "../config/agent"; +import { cn } from "@/lib/utils"; + +interface ProjectMember { + user_sub: string | null; + user_email: string; + added_at: string | null; +} + +interface ProjectMembersProps { + projectId: string; + className?: string; +} + +export function ProjectMembers({ projectId, className }: ProjectMembersProps) { + const [members, setMembers] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [newEmail, setNewEmail] = useState(""); + const [adding, setAdding] = useState(false); + const [removingEmail, setRemovingEmail] = useState(null); + + const fetchMembers = useCallback(async () => { + try { + setLoading(true); + setError(null); + const res = await fetch( + `${AGENT_CONFIG.HTTP_URL}api/projects/${projectId}/members`, + { credentials: "include" } + ); + if (!res.ok) { + throw new Error("Failed to load members"); + } + const data = await res.json(); + setMembers(data.members || []); + } catch (e) { + setError(e instanceof Error ? e.message : "Failed to load members"); + } finally { + setLoading(false); + } + }, [projectId]); + + useEffect(() => { + fetchMembers(); + }, [fetchMembers]); + + const handleAddMember = async (e: FormEvent) => { + e.preventDefault(); + const email = newEmail.trim().toLowerCase(); + if (!email || !email.includes("@")) return; + + try { + setAdding(true); + setError(null); + const res = await fetch( + `${AGENT_CONFIG.HTTP_URL}api/projects/${projectId}/members`, + { + method: "POST", + credentials: "include", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ email }), + } + ); + if (!res.ok) { + const data = await res.json().catch(() => ({})); + throw new Error(data.error || "Failed to add member"); + } + setNewEmail(""); + await fetchMembers(); + } catch (e) { + setError(e instanceof Error ? e.message : "Failed to add member"); + } finally { + setAdding(false); + } + }; + + const handleRemoveMember = async (userSub: string | null, userEmail: string) => { + if (members.length <= 1) return; + + try { + setRemovingEmail(userEmail); + setError(null); + // Use by-email endpoint for pending members (null user_sub), otherwise by user_sub + const endpoint = userSub + ? `${AGENT_CONFIG.HTTP_URL}api/projects/${projectId}/members/${encodeURIComponent(userSub)}` + : `${AGENT_CONFIG.HTTP_URL}api/projects/${projectId}/members/by-email/${encodeURIComponent(userEmail)}`; + const res = await fetch(endpoint, { + method: "DELETE", + credentials: "include", + }); + if (!res.ok) { + const data = await res.json().catch(() => ({})); + throw new Error(data.error || "Failed to remove member"); + } + await fetchMembers(); + } catch (e) { + setError(e instanceof Error ? e.message : "Failed to remove member"); + } finally { + setRemovingEmail(null); + } + }; + + const canRemove = members.length > 1; + + return ( +
+
+ + People with access +
+ + {loading ? ( +
+ +
+ ) : ( + <> + {/* Member list */} +
+ {members.map((member) => ( +
+
+
+ {member.user_email.charAt(0).toUpperCase()} +
+ + {member.user_email} + + {!member.user_sub && ( + + Pending + + )} +
+ +
+ ))} +
+ + {/* Add member form */} +
+ setNewEmail(e.target.value)} + disabled={adding} + className="flex-1" + /> + +
+ + {/* Error message */} + {error && ( +

+ {error} +

+ )} + + )} +
+ ); +} diff --git a/frontend/src/components/SessionClaimedModal.tsx b/frontend/src/components/SessionClaimedModal.tsx new file mode 100644 index 0000000..3b0d848 --- /dev/null +++ b/frontend/src/components/SessionClaimedModal.tsx @@ -0,0 +1,61 @@ +import { UserX } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; + +interface SessionClaimedModalProps { + claimedByEmail?: string; + onDismiss: () => void; + className?: string; +} + +export function SessionClaimedModal({ + claimedByEmail, + onDismiss, + className, +}: SessionClaimedModalProps) { + return ( +
+
+ {/* Header */} +
+
+
+ +
+
+

+ Session ended +

+

+ Another user took over this project +

+
+
+
+ + {/* Content */} +
+

+ Your editing session was taken over + {claimedByEmail && ( + <> + {" "} + by {claimedByEmail} + + )} + . You'll be redirected to the project list. +

+ + +
+
+
+ ); +} diff --git a/frontend/src/hooks/useMessageBus.ts b/frontend/src/hooks/useMessageBus.ts index af762cb..789e14e 100644 --- a/frontend/src/hooks/useMessageBus.ts +++ b/frontend/src/hooks/useMessageBus.ts @@ -20,6 +20,7 @@ interface UseMessageBusReturn { isConnected: boolean; error: string | null; connect: () => Promise; + connectWithExtra: (extra: Record) => Promise; disconnect: () => void; send: (type: MessageType, payload: Record) => void; } @@ -116,7 +117,7 @@ export const useMessageBus = ({ }; }, []); - const connect = useCallback(async () => { + const connectInner = useCallback(async (initExtra?: Record) => { if (!messageBusRef.current) { throw new Error("MessageBus not initialized"); } @@ -134,7 +135,8 @@ export const useMessageBus = ({ prevParams.wsUrl === nextParams.wsUrl && prevParams.sessionId === nextParams.sessionId; - if (isConnected && webSocketRef.current && sameParams) { + // Skip "already connected" check when initExtra is provided (force reconnect) + if (!initExtra && isConnected && webSocketRef.current && sameParams) { console.log("Already connected with same parameters, skipping..."); return; } @@ -158,7 +160,8 @@ export const useMessageBus = ({ webSocketRef.current = createWebSocketBus( wsUrl, messageBusRef.current, - sessionId + sessionId, + initExtra ); lastConnectParamsRef.current = nextParams; await webSocketRef.current.connect(); @@ -171,6 +174,14 @@ export const useMessageBus = ({ } }, [wsUrl, sessionId, isConnected]); + const connect = useCallback(async () => { + await connectInner(); + }, [connectInner]); + + const connectWithExtra = useCallback(async (extra: Record) => { + await connectInner(extra); + }, [connectInner]); + const disconnect = useCallback(() => { if (webSocketRef.current) { webSocketRef.current.disconnect(); @@ -228,6 +239,7 @@ export const useMessageBus = ({ isConnected, error, connect, + connectWithExtra, disconnect, send, }; diff --git a/frontend/src/screens/Create/index.tsx b/frontend/src/screens/Create/index.tsx index 83bf40d..78b8e5a 100644 --- a/frontend/src/screens/Create/index.tsx +++ b/frontend/src/screens/Create/index.tsx @@ -7,6 +7,7 @@ import { Play, RotateCcw, TabletIcon, + Users, X, } from "lucide-react"; import { @@ -20,6 +21,9 @@ import { type JsonObject, type RuntimeErrorPayload, } from "../../types/messages"; +import { ProjectLockedModal } from "@/components/ProjectLockedModal"; +import { ProjectMembers } from "@/components/ProjectMembers"; +import { SessionClaimedModal } from "@/components/SessionClaimedModal"; import { useCallback, useEffect, @@ -220,6 +224,14 @@ const Create = () => { const [pendingImages, setPendingImages] = useState([]); const [attachmentError, setAttachmentError] = useState(null); const fileInputRef = useRef(null); + const [projectLockedInfo, setProjectLockedInfo] = useState<{ + email: string; + lockedAt?: string; + } | null>(null); + const [sessionClaimed, setSessionClaimed] = useState<{ + byEmail?: string; + } | null>(null); + const [showMembers, setShowMembers] = useState(false); type ToolRun = { runId: string; @@ -525,6 +537,16 @@ const Create = () => { }, [MessageType.ERROR]: (message: Message) => { + // Check for project_locked error + const md = asObj(message.data); + if (md && md.code === "project_locked") { + const lockedBy = md.locked_by as { email?: string; at?: string } | undefined; + setProjectLockedInfo({ + email: lockedBy?.email || "another user", + lockedAt: lockedBy?.at, + }); + return; + } setMessages((prev) => [ ...prev, { @@ -538,6 +560,13 @@ const Create = () => { ]); }, + [MessageType.SESSION_CLAIMED]: (message: Message) => { + const md = asObj(message.data); + setSessionClaimed({ + byEmail: md?.claimed_by_email as string | undefined, + }); + }, + [MessageType.AGENT_PARTIAL]: (message: Message) => { const text = message.data.text; const id = message.id; @@ -1071,7 +1100,7 @@ const Create = () => { return combined; }, [messages, toolRunsByAssistantMsgId, reasoningByAssistantMsgId, resolvedSessionId]); - const { isConnecting, isConnected, error, connect, send } = useMessageBus({ + const { isConnecting, isConnected, error, connect, connectWithExtra, send } = useMessageBus({ wsUrl: AGENT_CONFIG.WS_URL, sessionId: resolvedSessionId || undefined, handlers: messageHandlers, @@ -1134,6 +1163,23 @@ const Create = () => { }; }, [isChatResizing, chatWidth]); + const handleTakeOverSession = useCallback(() => { + setProjectLockedInfo(null); + // Reconnect with force_claim in the INIT payload to take over the session + connectWithExtra({ force_claim: true }).catch((e) => + console.error("Force-claim reconnect failed:", e) + ); + }, [connectWithExtra]); + + const handleGoBackToProjects = useCallback(() => { + navigate("/"); + }, [navigate]); + + const handleSessionClaimedDismiss = useCallback(() => { + setSessionClaimed(null); + navigate("/"); + }, [navigate]); + const handleSendMessage = () => { const text = inputValue.trim(); const imgs = pendingImages; @@ -2355,11 +2401,33 @@ const Create = () => { ) : null} -
+
+ {projectInfo && ( + + )}
+ {showMembers && projectInfo && ( + + )} +
{
+ + {/* Project locking modals */} + {projectLockedInfo && ( + + )} + + {sessionClaimed && ( + + )} ); }; diff --git a/frontend/src/services/websocketBus.ts b/frontend/src/services/websocketBus.ts index 881c019..6e9535b 100644 --- a/frontend/src/services/websocketBus.ts +++ b/frontend/src/services/websocketBus.ts @@ -9,6 +9,7 @@ export interface WebSocketBusConfig { url: string; messageBus: MessageBus; sessionId?: string; + initExtra?: Record; } export class WebSocketBus { @@ -83,11 +84,14 @@ export class WebSocketBus { this.reconnectAttempts = 0; this.config.messageBus.setConnected(true); - const initData = this.config.sessionId - ? { session_id: this.config.sessionId } - : {}; + const initData = { + ...(this.config.sessionId ? { session_id: this.config.sessionId } : {}), + ...this.config.initExtra, + }; console.log("Sending INIT message with session_id:", this.config.sessionId); this.sendMessage(createMessage(MessageType.INIT, initData)); + // Clear initExtra after first use so reconnects don't replay it + this.config.initExtra = undefined; if (!settled) { settled = true; @@ -110,6 +114,7 @@ export class WebSocketBus { if ( event.code !== 1000 && + event.code !== 1008 && // Don't auto-reconnect on policy violations (e.g. project_locked) this.reconnectAttempts < this.maxReconnectAttempts ) { this.scheduleReconnect(); @@ -238,7 +243,8 @@ export class WebSocketBus { export const createWebSocketBus = ( url: string, messageBus: MessageBus, - sessionId?: string + sessionId?: string, + initExtra?: Record ): WebSocketBus => { - return new WebSocketBus({ url, messageBus, sessionId }); + return new WebSocketBus({ url, messageBus, sessionId, initExtra }); }; diff --git a/frontend/src/types/messages.ts b/frontend/src/types/messages.ts index ed2cf92..9d13b52 100644 --- a/frontend/src/types/messages.ts +++ b/frontend/src/types/messages.ts @@ -13,6 +13,7 @@ export enum MessageType { HITL_REQUEST = "hitl_request", HITL_RESPONSE = "hitl_response", RUNTIME_ERROR = "runtime_error", + SESSION_CLAIMED = "session_claimed", } export enum Sender { diff --git a/src/agent_core.py b/src/agent_core.py index 1c6c103..7abaab4 100644 --- a/src/agent_core.py +++ b/src/agent_core.py @@ -244,6 +244,7 @@ class MessageType(Enum): HITL_REQUEST = "hitl_request" HITL_RESPONSE = "hitl_response" RUNTIME_ERROR = "runtime_error" + SESSION_CLAIMED = "session_claimed" ERROR = "error" PING = "ping" diff --git a/src/projects/store.py b/src/projects/store.py index a9b12df..3cc9f89 100644 --- a/src/projects/store.py +++ b/src/projects/store.py @@ -5,6 +5,7 @@ import threading import uuid from dataclasses import dataclass +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any if TYPE_CHECKING: # pragma: no cover @@ -62,6 +63,9 @@ def ensure_projects_schema(client: HasuraClient) -> None: gitlab_project_id bigint NULL, gitlab_path text NULL, gitlab_web_url text NULL, + locked_by_sub text NULL, + locked_by_email text NULL, + locked_at timestamptz NULL, created_at timestamptz NOT NULL DEFAULT now(), updated_at timestamptz NOT NULL DEFAULT now(), deleted_at timestamptz NULL @@ -76,6 +80,32 @@ def ensure_projects_schema(client: HasuraClient) -> None: ADD COLUMN IF NOT EXISTS gitlab_path text NULL; ALTER TABLE amicable_meta.projects ADD COLUMN IF NOT EXISTS gitlab_web_url text NULL; + ALTER TABLE amicable_meta.projects + ADD COLUMN IF NOT EXISTS locked_by_sub text NULL; + ALTER TABLE amicable_meta.projects + ADD COLUMN IF NOT EXISTS locked_by_email text NULL; + ALTER TABLE amicable_meta.projects + ADD COLUMN IF NOT EXISTS locked_at timestamptz NULL; + + CREATE TABLE IF NOT EXISTS amicable_meta.project_members ( + project_id text NOT NULL REFERENCES amicable_meta.projects(project_id) ON DELETE CASCADE, + user_sub text NULL, + user_email text NOT NULL, + added_at timestamptz NOT NULL DEFAULT now(), + added_by_sub text NULL, + PRIMARY KEY (project_id, user_email) + ); + CREATE INDEX IF NOT EXISTS idx_project_members_user + ON amicable_meta.project_members(user_sub); + CREATE INDEX IF NOT EXISTS idx_project_members_email + ON amicable_meta.project_members(user_email); + + -- Migration: add existing project owners as members + INSERT INTO amicable_meta.project_members (project_id, user_sub, user_email, added_at) + SELECT p.project_id, p.owner_sub, LOWER(p.owner_email), p.created_at + FROM amicable_meta.projects p + WHERE p.deleted_at IS NULL + ON CONFLICT DO NOTHING; """.strip() ) _schema_ready = True @@ -103,6 +133,23 @@ class Project: updated_at: str | None = None +@dataclass(frozen=True) +class ProjectMember: + project_id: str + user_sub: str | None # None if invited by email but not yet logged in + user_email: str + added_at: str | None = None + added_by_sub: str | None = None + + +@dataclass(frozen=True) +class ProjectLock: + project_id: str + locked_by_sub: str + locked_by_email: str + locked_at: str + + def _tuples_to_dicts(res: dict[str, Any]) -> list[dict[str, Any]]: rows = res.get("result") if not isinstance(rows, list) or len(rows) < 2: @@ -232,7 +279,9 @@ def get_project_by_id( p = _get_project_by_id_any_owner(client, project_id=project_id) if not p: return None - if p.owner_sub != owner.sub: + if not is_project_member( + client, project_id=project_id, user_sub=owner.sub, user_email=owner.email + ): return None return p @@ -256,10 +305,13 @@ def get_project_by_slug( if not rows: return None r = rows[0] - if str(r.get("owner_sub")) != owner.sub: + project_id = str(r["project_id"]) + if not is_project_member( + client, project_id=project_id, user_sub=owner.sub, user_email=owner.email + ): return None return Project( - project_id=str(r["project_id"]), + project_id=project_id, owner_sub=str(r["owner_sub"]), owner_email=str(r["owner_email"]), name=str(r["name"]), @@ -357,12 +409,17 @@ def list_projects(client: HasuraClient, *, owner: ProjectOwner) -> list[Project] ensure_projects_schema(client) res = client.run_sql( f""" - SELECT project_id, owner_sub, owner_email, name, slug, sandbox_id, template_id, - gitlab_project_id, gitlab_path, gitlab_web_url, - created_at, updated_at - FROM amicable_meta.projects - WHERE owner_sub = {_sql_str(owner.sub)} AND deleted_at IS NULL - ORDER BY updated_at DESC; + SELECT p.project_id, p.owner_sub, p.owner_email, p.name, p.slug, p.sandbox_id, p.template_id, + p.gitlab_project_id, p.gitlab_path, p.gitlab_web_url, + p.created_at, p.updated_at + FROM amicable_meta.projects p + WHERE p.deleted_at IS NULL + AND EXISTS ( + SELECT 1 FROM amicable_meta.project_members pm + WHERE pm.project_id = p.project_id + AND (pm.user_sub = {_sql_str(owner.sub)} OR pm.user_email = {_sql_str(owner.email.lower())}) + ) + ORDER BY p.updated_at DESC; """.strip(), read_only=True, ) @@ -409,11 +466,17 @@ def set_project_slug( new_slug: str, ) -> Project: ensure_projects_schema(client) + # Check membership before updating client.run_sql( f""" - UPDATE amicable_meta.projects + UPDATE amicable_meta.projects p SET slug = {_sql_str(new_slug)}, updated_at = now() - WHERE project_id = {_sql_str(project_id)} AND owner_sub = {_sql_str(owner.sub)} AND deleted_at IS NULL; + WHERE p.project_id = {_sql_str(project_id)} AND p.deleted_at IS NULL + AND EXISTS ( + SELECT 1 FROM amicable_meta.project_members pm + WHERE pm.project_id = p.project_id + AND (pm.user_sub = {_sql_str(owner.sub)} OR pm.user_email = {_sql_str(owner.email.lower())}) + ); """.strip() ) p = get_project_by_id(client, owner=owner, project_id=project_id) @@ -432,14 +495,20 @@ def set_gitlab_metadata( gitlab_web_url: str | None, ) -> Project: ensure_projects_schema(client) + # Check membership before updating client.run_sql( f""" - UPDATE amicable_meta.projects + UPDATE amicable_meta.projects p SET gitlab_project_id = {str(int(gitlab_project_id)) if gitlab_project_id is not None else "NULL"}, gitlab_path = {_sql_str(gitlab_path) if gitlab_path is not None else "NULL"}, gitlab_web_url = {_sql_str(gitlab_web_url) if gitlab_web_url is not None else "NULL"}, updated_at = now() - WHERE project_id = {_sql_str(project_id)} AND owner_sub = {_sql_str(owner.sub)} AND deleted_at IS NULL; + WHERE p.project_id = {_sql_str(project_id)} AND p.deleted_at IS NULL + AND EXISTS ( + SELECT 1 FROM amicable_meta.project_members pm + WHERE pm.project_id = p.project_id + AND (pm.user_sub = {_sql_str(owner.sub)} OR pm.user_email = {_sql_str(owner.email.lower())}) + ); """.strip() ) p = get_project_by_id(client, owner=owner, project_id=project_id) @@ -498,6 +567,14 @@ def create_project( ) p = _get_project_by_id_any_owner(client, project_id=project_id) if p and p.owner_sub == owner.sub: + # Add creator as first member + add_project_member( + client, + project_id=project_id, + user_sub=owner.sub, + user_email=owner.email, + added_by_sub=None, + ) return p raise RuntimeError("failed to allocate unique project slug") @@ -510,8 +587,11 @@ def ensure_project_for_id( ensure_projects_schema(client) existing = _get_project_by_id_any_owner(client, project_id=project_id) if existing: - if existing.owner_sub != owner.sub: - raise PermissionError("project belongs to a different user") + # Check membership instead of ownership + if not is_project_member( + client, project_id=project_id, user_sub=owner.sub, user_email=owner.email + ): + raise PermissionError("not a project member") return existing # If a project row exists but was soft-deleted, resurrect it instead of failing @@ -531,6 +611,14 @@ def ensure_project_for_id( ) revived = _get_project_by_id_any_owner(client, project_id=project_id) if revived: + # Ensure member record exists after resurrection + add_project_member( + client, + project_id=project_id, + user_sub=owner.sub, + user_email=owner.email, + added_by_sub=None, + ) return revived short = project_id.replace("-", "")[:8] @@ -552,6 +640,14 @@ def ensure_project_for_id( ) created = _get_project_by_id_any_owner(client, project_id=project_id) if created and created.owner_sub == owner.sub: + # Add creator as first member + add_project_member( + client, + project_id=project_id, + user_sub=owner.sub, + user_email=owner.email, + added_by_sub=None, + ) return created raise RuntimeError("failed to auto-create project") @@ -563,7 +659,7 @@ def rename_project( ensure_projects_schema(client) base = slugify(new_name) - # Ensure project exists and is owned. + # Ensure project exists and user is a member (membership checked by get_project_by_id) existing = get_project_by_id(client, owner=owner, project_id=project_id) if not existing: raise PermissionError("project not found") @@ -579,7 +675,12 @@ def rename_project( f""" UPDATE amicable_meta.projects SET name = {_sql_str(new_name)}, slug = {_sql_str(slug)}, updated_at = now() - WHERE project_id = {_sql_str(project_id)} AND owner_sub = {_sql_str(owner.sub)} AND deleted_at IS NULL; + WHERE project_id = {_sql_str(project_id)} AND deleted_at IS NULL + AND EXISTS ( + SELECT 1 FROM amicable_meta.project_members pm + WHERE pm.project_id = {_sql_str(project_id)} + AND (pm.user_sub = {_sql_str(owner.sub)} OR pm.user_email = {_sql_str(owner.email.lower())}) + ); """.strip() ) updated = get_project_by_id(client, owner=owner, project_id=project_id) @@ -593,12 +694,17 @@ def mark_project_deleted( client: HasuraClient, *, owner: ProjectOwner, project_id: str ) -> None: ensure_projects_schema(client) - # Mark deleted first so it disappears from lists immediately. + # Verify user is a member before allowing delete + if not is_project_member( + client, project_id=project_id, user_sub=owner.sub, user_email=owner.email + ): + raise PermissionError("not a member") + # Mark deleted so it disappears from lists immediately. client.run_sql( f""" UPDATE amicable_meta.projects SET deleted_at = now(), updated_at = now() - WHERE project_id = {_sql_str(project_id)} AND owner_sub = {_sql_str(owner.sub)} AND deleted_at IS NULL; + WHERE project_id = {_sql_str(project_id)} AND deleted_at IS NULL; """.strip() ) @@ -607,9 +713,257 @@ def hard_delete_project_row( client: HasuraClient, *, owner: ProjectOwner, project_id: str ) -> None: ensure_projects_schema(client) + # Verify user is a member before allowing hard delete + if not is_project_member( + client, project_id=project_id, user_sub=owner.sub, user_email=owner.email + ): + raise PermissionError("not a member") client.run_sql( f""" DELETE FROM amicable_meta.projects - WHERE project_id = {_sql_str(project_id)} AND owner_sub = {_sql_str(owner.sub)}; + WHERE project_id = {_sql_str(project_id)}; + """.strip() + ) + + +# --------------------------------------------------------------------------- +# Project Members +# --------------------------------------------------------------------------- + + +def _get_member_by_email( + client: HasuraClient, *, project_id: str, user_email: str +) -> ProjectMember | None: + res = client.run_sql( + f""" + SELECT project_id, user_sub, user_email, added_at, added_by_sub + FROM amicable_meta.project_members + WHERE project_id = {_sql_str(project_id)} AND user_email = {_sql_str(user_email.lower())} + LIMIT 1; + """.strip(), + read_only=True, + ) + rows = _tuples_to_dicts(res) + if not rows: + return None + r = rows[0] + return ProjectMember( + project_id=str(r["project_id"]), + user_sub=str(r["user_sub"]) if r.get("user_sub") else None, + user_email=str(r["user_email"]), + added_at=str(r.get("added_at")) if r.get("added_at") else None, + added_by_sub=str(r.get("added_by_sub")) if r.get("added_by_sub") else None, + ) + + +def add_project_member( + client: HasuraClient, + *, + project_id: str, + user_email: str, + user_sub: str | None = None, + added_by_sub: str | None = None, +) -> ProjectMember: + """Add a member to a project. If user_sub is None, they'll be matched on first login.""" + ensure_projects_schema(client) + user_email = user_email.strip().lower() + + # Check if already a member by email — skip only if no new user_sub to backfill + existing = _get_member_by_email(client, project_id=project_id, user_email=user_email) + if existing and (existing.user_sub or not user_sub): + return existing + + sub_sql = _sql_str(user_sub) if user_sub else "NULL" + added_by_sql = _sql_str(added_by_sub) if added_by_sub else "NULL" + + client.run_sql( + f""" + INSERT INTO amicable_meta.project_members (project_id, user_sub, user_email, added_by_sub) + VALUES ({_sql_str(project_id)}, {sub_sql}, {_sql_str(user_email)}, {added_by_sql}) + ON CONFLICT (project_id, user_email) DO UPDATE SET user_sub = COALESCE(EXCLUDED.user_sub, amicable_meta.project_members.user_sub); + """.strip() + ) + # Re-fetch to get DB-populated fields (added_at, resolved user_sub from upsert) + member = _get_member_by_email(client, project_id=project_id, user_email=user_email) + if member: + return member + return ProjectMember( + project_id=project_id, + user_sub=user_sub, + user_email=user_email, + added_by_sub=added_by_sub, + added_at=datetime.now(tz=UTC).isoformat(), + ) + + +def list_project_members(client: HasuraClient, *, project_id: str) -> list[ProjectMember]: + """List all members of a project.""" + ensure_projects_schema(client) + res = client.run_sql( + f""" + SELECT project_id, user_sub, user_email, added_at, added_by_sub + FROM amicable_meta.project_members + WHERE project_id = {_sql_str(project_id)} + ORDER BY added_at ASC; + """.strip(), + read_only=True, + ) + out: list[ProjectMember] = [] + for r in _tuples_to_dicts(res): + out.append( + ProjectMember( + project_id=str(r["project_id"]), + user_sub=str(r["user_sub"]) if r.get("user_sub") else None, + user_email=str(r["user_email"]), + added_at=str(r.get("added_at")) if r.get("added_at") else None, + added_by_sub=str(r.get("added_by_sub")) if r.get("added_by_sub") else None, + ) + ) + return out + + +def remove_project_member( + client: HasuraClient, *, project_id: str, user_sub: str +) -> bool: + """Remove a member from a project. Returns False if they were the last member or target not found.""" + ensure_projects_schema(client) + members = list_project_members(client, project_id=project_id) + if len(members) <= 1: + return False + if not any(m.user_sub == user_sub for m in members): + return False + client.run_sql( + f""" + DELETE FROM amicable_meta.project_members + WHERE project_id = {_sql_str(project_id)} AND user_sub = {_sql_str(user_sub)}; + """.strip() + ) + return True + + +def remove_project_member_by_email( + client: HasuraClient, *, project_id: str, user_email: str +) -> bool: + """Remove a pending member by email. Returns False if they were the last member or target not found.""" + ensure_projects_schema(client) + user_email = user_email.strip().lower() + members = list_project_members(client, project_id=project_id) + if len(members) <= 1: + return False + if not any(m.user_email == user_email for m in members): + return False + client.run_sql( + f""" + DELETE FROM amicable_meta.project_members + WHERE project_id = {_sql_str(project_id)} AND user_email = {_sql_str(user_email)}; + """.strip() + ) + return True + + +def is_project_member( + client: HasuraClient, *, project_id: str, user_sub: str, user_email: str +) -> bool: + """Check if a user is a member of a project (by sub or email).""" + ensure_projects_schema(client) + res = client.run_sql( + f""" + SELECT 1 FROM amicable_meta.project_members + WHERE project_id = {_sql_str(project_id)} + AND (user_sub = {_sql_str(user_sub)} OR user_email = {_sql_str(user_email.lower())}) + LIMIT 1; + """.strip(), + read_only=True, + ) + rows = _tuples_to_dicts(res) + return bool(rows) + + +# --------------------------------------------------------------------------- +# Session Locking +# --------------------------------------------------------------------------- + + +def get_project_lock(client: HasuraClient, *, project_id: str) -> ProjectLock | None: + """Get current lock info for a project.""" + ensure_projects_schema(client) + res = client.run_sql( + f""" + SELECT locked_by_sub, locked_by_email, locked_at + FROM amicable_meta.projects + WHERE project_id = {_sql_str(project_id)} AND deleted_at IS NULL AND locked_by_sub IS NOT NULL + LIMIT 1; + """.strip(), + read_only=True, + ) + rows = _tuples_to_dicts(res) + if not rows or not rows[0].get("locked_by_sub"): + return None + r = rows[0] + return ProjectLock( + project_id=project_id, + locked_by_sub=str(r["locked_by_sub"]), + locked_by_email=str(r.get("locked_by_email") or ""), + locked_at=str(r.get("locked_at") or ""), + ) + + +def acquire_project_lock( + client: HasuraClient, + *, + project_id: str, + user_sub: str, + user_email: str, + force: bool = False, +) -> ProjectLock | None: + """Try to acquire lock atomically. Returns None if locked by someone else (unless force=True).""" + ensure_projects_schema(client) + + # Atomic conditional update - only acquire if unlocked, held by self, or force + if force: + # Force always succeeds + client.run_sql( + f""" + UPDATE amicable_meta.projects + SET locked_by_sub = {_sql_str(user_sub)}, locked_by_email = {_sql_str(user_email)}, + locked_at = now(), updated_at = now() + WHERE project_id = {_sql_str(project_id)} AND deleted_at IS NULL; + """.strip() + ) + else: + # Only acquire if unlocked or already held by this user + client.run_sql( + f""" + UPDATE amicable_meta.projects + SET locked_by_sub = {_sql_str(user_sub)}, locked_by_email = {_sql_str(user_email)}, + locked_at = now(), updated_at = now() + WHERE project_id = {_sql_str(project_id)} + AND deleted_at IS NULL + AND (locked_by_sub IS NULL OR locked_by_sub = {_sql_str(user_sub)}); + """.strip() + ) + + # Check if we actually got the lock + lock = get_project_lock(client, project_id=project_id) + if lock and lock.locked_by_sub == user_sub: + return ProjectLock( + project_id=project_id, + locked_by_sub=user_sub, + locked_by_email=user_email, + locked_at=lock.locked_at, + ) + return None + + +def release_project_lock( + client: HasuraClient, *, project_id: str, user_sub: str +) -> None: + """Release lock if held by user_sub.""" + ensure_projects_schema(client) + client.run_sql( + f""" + UPDATE amicable_meta.projects + SET locked_by_sub = NULL, locked_by_email = NULL, locked_at = NULL, updated_at = now() + WHERE project_id = {_sql_str(project_id)} AND locked_by_sub = {_sql_str(user_sub)} AND deleted_at IS NULL; """.strip() ) diff --git a/src/runtimes/ws_server.py b/src/runtimes/ws_server.py index 8f72bf0..4bb9f64 100644 --- a/src/runtimes/ws_server.py +++ b/src/runtimes/ws_server.py @@ -46,6 +46,11 @@ # Best-effort in-memory limiter for runtime error auto-heal. _runtime_autoheal_state_by_project: dict[str, Any] = {} +# Track active WebSocket connections for session locking. +# Maps WebSocket -> (project_id, user_sub) for lock release on disconnect. +_ws_session_map: dict[WebSocket, tuple[str, str]] = {} +_background_tasks: set[asyncio.Task[Any]] = set() + _naming_llm: Any = None @@ -1116,6 +1121,166 @@ def _cleanup() -> None: return JSONResponse({"status": "deleting"}, status_code=202, background=bg) +# --------------------------------------------------------------------------- +# Project Members API +# --------------------------------------------------------------------------- + + +@app.get("/api/projects/{project_id}/members") +async def api_list_project_members(project_id: str, request: Request) -> JSONResponse: + """List all members of a project.""" + _require_hasura() + try: + sub, email = _get_owner_from_request(request) + except PermissionError: + return JSONResponse({"error": "not_authenticated"}, status_code=401) + + from src.db.provisioning import hasura_client_from_env + from src.projects.store import ProjectOwner, get_project_by_id, list_project_members + + def _list_sync(): + client = hasura_client_from_env() + owner = ProjectOwner(sub=sub, email=email) + project = get_project_by_id(client, owner=owner, project_id=project_id) + if not project: + return None + return list_project_members(client, project_id=project_id) + + members = await asyncio.to_thread(_list_sync) + if members is None: + return JSONResponse({"error": "not_found"}, status_code=404) + + return JSONResponse({ + "members": [ + { + "user_sub": m.user_sub, + "user_email": m.user_email, + "added_at": m.added_at, + } + for m in members + ] + }) + + +@app.post("/api/projects/{project_id}/members") +async def api_add_project_member(project_id: str, request: Request) -> JSONResponse: + """Add a member to a project by email.""" + _require_hasura() + try: + sub, email = _get_owner_from_request(request) + except PermissionError: + return JSONResponse({"error": "not_authenticated"}, status_code=401) + + body: Any + try: + body = await request.json() + except Exception: + body = {} + + new_email = str(body.get("email") or "").strip().lower() + if not new_email or "@" not in new_email: + return JSONResponse({"error": "invalid_email"}, status_code=400) + + from src.db.provisioning import hasura_client_from_env + from src.projects.store import ProjectOwner, add_project_member, get_project_by_id + + def _add_sync(): + client = hasura_client_from_env() + owner = ProjectOwner(sub=sub, email=email) + project = get_project_by_id(client, owner=owner, project_id=project_id) + if not project: + return None + return add_project_member( + client, + project_id=project_id, + user_email=new_email, + added_by_sub=sub, + ) + + member = await asyncio.to_thread(_add_sync) + if member is None: + return JSONResponse({"error": "not_found"}, status_code=404) + + return JSONResponse({ + "user_email": member.user_email, + "added_at": member.added_at, + }, status_code=201) + + +@app.delete("/api/projects/{project_id}/members/{user_sub}") +async def api_remove_project_member( + project_id: str, user_sub: str, request: Request +) -> JSONResponse: + """Remove a member from a project by user_sub.""" + _require_hasura() + try: + sub, email = _get_owner_from_request(request) + except PermissionError: + return JSONResponse({"error": "not_authenticated"}, status_code=401) + + from src.db.provisioning import hasura_client_from_env + from src.projects.store import ( + ProjectOwner, + get_project_by_id, + remove_project_member, + ) + + def _remove_sync(): + client = hasura_client_from_env() + owner = ProjectOwner(sub=sub, email=email) + project = get_project_by_id(client, owner=owner, project_id=project_id) + if not project: + return "not_found" + success = remove_project_member(client, project_id=project_id, user_sub=user_sub) + return "ok" if success else "remove_failed" + + result = await asyncio.to_thread(_remove_sync) + if result == "not_found": + return JSONResponse({"error": "not_found"}, status_code=404) + if result == "remove_failed": + return JSONResponse({"error": "cannot_remove_member"}, status_code=400) + + return JSONResponse({"ok": True}) + + +@app.delete("/api/projects/{project_id}/members/by-email/{user_email:path}") +async def api_remove_project_member_by_email( + project_id: str, user_email: str, request: Request +) -> JSONResponse: + """Remove a pending member from a project by email (for users who haven't logged in yet).""" + _require_hasura() + try: + sub, email = _get_owner_from_request(request) + except PermissionError: + return JSONResponse({"error": "not_authenticated"}, status_code=401) + + from src.db.provisioning import hasura_client_from_env + from src.projects.store import ( + ProjectOwner, + get_project_by_id, + remove_project_member_by_email, + ) + + def _remove_sync(): + client = hasura_client_from_env() + owner = ProjectOwner(sub=sub, email=email) + project = get_project_by_id(client, owner=owner, project_id=project_id) + if not project: + return "not_found" + success = remove_project_member_by_email(client, project_id=project_id, user_email=user_email) + # The underlying store only tells us whether a removal happened, not why it failed. + # Do not assume every failure is due to "last member"; report a generic failure instead. + return "ok" if success else "cannot_remove_member" + + result = await asyncio.to_thread(_remove_sync) + if result == "not_found": + return JSONResponse({"error": "not_found"}, status_code=404) + if result == "cannot_remove_member": + return JSONResponse({"error": "cannot_remove_member"}, status_code=400) + + return JSONResponse({"ok": True}) + + def _ensure_project_access(request: Request, *, project_id: str): """Return the project if the request is allowed to access it, else raise PermissionError.""" _require_hasura() @@ -2222,7 +2387,8 @@ async def _handle_ws(ws: WebSocket) -> None: _agent = Agent() agent = _agent - while True: + try: + while True: try: raw = await ws.receive_text() except WebSocketDisconnect: @@ -2306,6 +2472,99 @@ def _load_sync( ) await ws.close(code=1011) return + + # Session locking: check/acquire lock + force_claim = bool(data.get("force_claim", data.get("force", False))) + from src.projects.store import ( + acquire_project_lock, + get_project_lock, + ) + + def _check_and_acquire_lock( + _sid: str = str(session_id), + _sub: str = sub, + _email: str = email, + _force: bool = force_claim, + ) -> tuple[str, dict[str, Any] | None]: + client = hasura_client_from_env() + current_lock = get_project_lock(client, project_id=_sid) + if current_lock and current_lock.locked_by_sub != _sub and not _force: + return "locked", { + "locked_by_email": current_lock.locked_by_email, + "locked_at": current_lock.locked_at, + } + lock = acquire_project_lock( + client, + project_id=_sid, + user_sub=_sub, + user_email=_email, + force=_force, + ) + if not lock: + # Race: someone else acquired it between check and acquire + current = get_project_lock(client, project_id=_sid) + return "locked", { + "locked_by_email": current.locked_by_email if current else "unknown", + "locked_at": current.locked_at if current else "", + } + return "ok", None + + # If force-claiming, notify the previous session holder first + if force_claim: + for other_ws, (other_sess, other_sub) in list(_ws_session_map.items()): + if other_sess == str(session_id) and other_sub != sub: + try: + await other_ws.send_json( + Message.new( + MessageType.SESSION_CLAIMED, + {"claimed_by_email": email}, + session_id=str(session_id), + ).to_dict() + ) + await other_ws.close(code=1000) + except Exception: + logger.debug( + "Failed to notify/close previous session websocket " + "(session_id=%s, previous_sub=%s)", + session_id, + other_sub, + exc_info=True, + ) + _ws_session_map.pop(other_ws, None) + + try: + lock_status, lock_info = await asyncio.to_thread(_check_and_acquire_lock) + except Exception as exc: + logger.exception("Lock acquisition failed for session %s: %s", session_id, exc) + await ws.send_json( + Message.new( + MessageType.ERROR, + {"code": "lock_error", "error": "Failed to check session lock"}, + session_id=session_id, + ).to_dict() + ) + await ws.close(code=1011) + return + if lock_status == "locked": + await ws.send_json( + Message.new( + MessageType.ERROR, + { + "code": "project_locked", + "locked_by": { + "email": lock_info["locked_by_email"] if lock_info else "unknown", + "at": lock_info["locked_at"] if lock_info else "", + }, + }, + session_id=session_id, + ).to_dict() + ) + await ws.close(code=1008) + return + + # Register this websocket for lock tracking + _ws_session_map[ws] = (str(session_id), sub) + template_id = ( getattr(project, "template_id", None) if project is not None else None ) @@ -2680,6 +2939,24 @@ def _load_sync( ).to_dict() ) continue + finally: + conn_info = _ws_session_map.pop(ws, None) + if conn_info: + sess_id, u_sub = conn_info + + def _release_lock_on_disconnect() -> None: + try: + from src.db.provisioning import hasura_client_from_env + from src.projects.store import release_project_lock + + client = hasura_client_from_env() + release_project_lock(client, project_id=sess_id, user_sub=u_sub) + except Exception as exc: + logger.warning("Failed to release lock for session %s: %s", sess_id, exc) + + task = asyncio.create_task(asyncio.to_thread(_release_lock_on_disconnect)) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) @app.websocket("/") diff --git a/tests/test_projects_store.py b/tests/test_projects_store.py index 5918007..409df12 100644 --- a/tests/test_projects_store.py +++ b/tests/test_projects_store.py @@ -21,6 +21,8 @@ class FakeHasuraClient: def __init__(self) -> None: self.projects: dict[str, dict] = {} + self.members: dict[tuple[str, str], dict] = {} # (project_id, user_key) -> member + self.schema_created_tables: dict[str, bool] = {} class _Cfg: source_name = "default" @@ -32,10 +34,86 @@ def run_sql(self, sql: str, *, read_only: bool = False): sql = sql.strip() sql_l = sql.lower() - # Schema creation: ignore. - if sql.lower().startswith("create schema") or sql.lower().startswith( - "create table" + # Schema creation: track tables and ignore. + # Handle multi-statement SQL for schema migration + if "create schema" in sql_l or "create table" in sql_l or "alter table" in sql_l or "create index" in sql_l: + if "create table" in sql_l and "project_members" in sql_l: + self.schema_created_tables["project_members"] = True + if "create table" in sql_l and "amicable_meta.projects" in sql_l: + self.schema_created_tables["projects"] = True + return {"result_type": "CommandOk", "result": []} + + # --------------------------------------------------------------- + # Session Locking (must come before generic project SELECT) + # --------------------------------------------------------------- + + # SELECT lock info (get_project_lock) - check for "locked_by_sub is not null" + if ( + "locked_by_sub" in sql_l + and "locked_by_sub is not null" in sql_l + and "from amicable_meta.projects" in sql_l + and sql_l.startswith("select") + ): + pid_match = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I) + if pid_match: + row = self.projects.get(pid_match.group(1)) + if row and not row.get("deleted_at") and row.get("locked_by_sub"): + return { + "result_type": "TuplesOk", + "result": [ + ["locked_by_sub", "locked_by_email", "locked_at"], + [row["locked_by_sub"], row.get("locked_by_email", ""), row.get("locked_at")], + ], + } + return { + "result_type": "TuplesOk", + "result": [["locked_by_sub", "locked_by_email", "locked_at"]], + } + + # UPDATE lock (acquire/release) - check for "set locked_by_sub" + if ( + sql_l.startswith("update amicable_meta.projects") + and "set locked_by_sub" in sql_l ): + pid_match = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I) + if pid_match: + row = self.projects.get(pid_match.group(1)) + if row and not row.get("deleted_at"): + if "locked_by_sub = null" in sql_l.lower(): + # Release - check user matches + sub_match = re.search( + r"and locked_by_sub\s*=\s*'([^']+)'", sql, flags=re.I + ) + if sub_match and row.get("locked_by_sub") == sub_match.group(1): + row["locked_by_sub"] = None + row["locked_by_email"] = None + row["locked_at"] = None + else: + # Acquire - extract the new user_sub and email + sub_match = re.search( + r"set locked_by_sub\s*=\s*'([^']+)'", sql, flags=re.I + ) + email_match = re.search( + r"locked_by_email\s*=\s*'([^']+)'", sql, flags=re.I + ) + if sub_match: + new_sub = sub_match.group(1) + new_email = email_match.group(1) if email_match else "" + # Check for atomic conditional update pattern: + # (locked_by_sub IS NULL OR locked_by_sub = 'user') + if "locked_by_sub is null or locked_by_sub" in sql_l: + # Conditional acquire - only if unlocked or held by same user + current_holder = row.get("locked_by_sub") + if current_holder is None or current_holder == new_sub: + row["locked_by_sub"] = new_sub + row["locked_by_email"] = new_email + row["locked_at"] = "now" + # else: don't update (someone else holds it) + else: + # Unconditional (force) acquire + row["locked_by_sub"] = new_sub + row["locked_by_email"] = new_email + row["locked_at"] = "now" return {"result_type": "CommandOk", "result": []} # SELECT by project_id. @@ -115,7 +193,60 @@ def run_sql(self, sql: str, *, read_only: bool = False): ] return {"result_type": "TuplesOk", "result": [header, data]} - # List by owner_sub. + # List projects with EXISTS on members (new membership-based query) + if "exists" in sql_l and "project_members" in sql_l and "order by" in sql_l: + sub_match = re.search(r"pm\.user_sub\s*=\s*'([^']+)'", sql, flags=re.I) + email_match = re.search(r"pm\.user_email\s*=\s*'([^']+)'", sql, flags=re.I) + sub = sub_match.group(1) if sub_match else None + email = email_match.group(1).lower() if email_match else None + + # Find projects where user is a member + member_project_ids = set() + for m in self.members.values(): + if m.get("user_sub") == sub or m.get("user_email") == email: + member_project_ids.add(m["project_id"]) + + rows = [ + p + for p in self.projects.values() + if p["project_id"] in member_project_ids and not p.get("deleted_at") + ] + + header = [ + "project_id", + "owner_sub", + "owner_email", + "name", + "slug", + "sandbox_id", + "template_id", + "gitlab_project_id", + "gitlab_path", + "gitlab_web_url", + "created_at", + "updated_at", + ] + out = [header] + for r in rows: + out.append( + [ + r["project_id"], + r["owner_sub"], + r["owner_email"], + r["name"], + r["slug"], + r.get("sandbox_id"), + r.get("template_id"), + r.get("gitlab_project_id"), + r.get("gitlab_path"), + r.get("gitlab_web_url"), + r.get("created_at"), + r.get("updated_at"), + ] + ) + return {"result_type": "TuplesOk", "result": out} + + # List by owner_sub (legacy - but still needed for backward compatibility in tests). m = re.search(r"where owner_sub\s*=\s*'([^']+)'", sql, flags=re.I) if m and "order by updated_at" in sql_l: sub = m.group(1) @@ -184,16 +315,15 @@ def run_sql(self, sql: str, *, read_only: bool = False): } return {"result_type": "CommandOk", "result": []} - # UPDATE rename. + # UPDATE rename (membership is checked before SQL call, so no owner_sub in WHERE). if sql_l.startswith("update amicable_meta.projects") and "set name" in sql_l: pid = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I).group( 1 ) # type: ignore[union-attr] - sub = re.search(r"and owner_sub\s*=\s*'([^']+)'", sql, flags=re.I).group(1) # type: ignore[union-attr] name = re.search(r"set name\s*=\s*'([^']*)'", sql, flags=re.I).group(1) # type: ignore[union-attr] slug = re.search(r"slug\s*=\s*'([^']*)'", sql, flags=re.I).group(1) # type: ignore[union-attr] row = self.projects.get(pid) - if row and row["owner_sub"] == sub and not row.get("deleted_at"): + if row and not row.get("deleted_at"): # enforce slug uniqueness if any( ( @@ -209,7 +339,7 @@ def run_sql(self, sql: str, *, read_only: bool = False): row["updated_at"] = "t1" return {"result_type": "CommandOk", "result": []} - # UPDATE mark deleted. + # UPDATE mark deleted (membership is checked before SQL call, so no owner_sub in WHERE). if ( sql_l.startswith("update amicable_meta.projects") and "set deleted_at" in sql_l @@ -217,24 +347,136 @@ def run_sql(self, sql: str, *, read_only: bool = False): pid = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I).group( 1 ) # type: ignore[union-attr] - sub = re.search(r"and owner_sub\s*=\s*'([^']+)'", sql, flags=re.I).group(1) # type: ignore[union-attr] row = self.projects.get(pid) - if row and row["owner_sub"] == sub and not row.get("deleted_at"): + if row and not row.get("deleted_at"): row["deleted_at"] = "t_del" row["updated_at"] = "t_del" return {"result_type": "CommandOk", "result": []} - # DELETE. + # DELETE project (membership is checked before SQL call, so no owner_sub in WHERE). if sql_l.startswith("delete from amicable_meta.projects"): pid = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I).group( 1 ) # type: ignore[union-attr] - sub = re.search(r"and owner_sub\s*=\s*'([^']+)'", sql, flags=re.I).group(1) # type: ignore[union-attr] row = self.projects.get(pid) - if row and row["owner_sub"] == sub: + if row: self.projects.pop(pid, None) return {"result_type": "CommandOk", "result": []} + # --------------------------------------------------------------- + # Project Members + # --------------------------------------------------------------- + + # INSERT member + if sql_l.startswith("insert into amicable_meta.project_members"): + vals = re.search(r"values\s*\((.*)\)\s*on conflict", sql, flags=re.I | re.S) + if vals: + parts = [p.strip() for p in vals.group(1).split(",")] + pid = parts[0].strip("'") + user_sub = parts[1].strip("'") if parts[1].strip() != "NULL" else None + user_email = parts[2].strip("'").lower() + added_by = ( + parts[3].strip("'") + if len(parts) > 3 and parts[3].strip() != "NULL" + else None + ) + # Use user_sub as key if available, else email + key = (pid, user_sub or user_email) + if key not in self.members: + self.members[key] = { + "project_id": pid, + "user_sub": user_sub, + "user_email": user_email, + "added_at": "t0", + "added_by_sub": added_by, + } + return {"result_type": "CommandOk", "result": []} + + # SELECT members by project_id + if "from amicable_meta.project_members" in sql_l and sql_l.startswith("select"): + pid_match = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I) + if pid_match: + pid = pid_match.group(1) + + # Check for is_project_member query (SELECT 1 with user_sub/email OR) + if sql_l.startswith("select 1") and "(user_sub" in sql_l: + # Parse the user_sub and user_email from the OR clause + sub_match = re.search(r"user_sub\s*=\s*'([^']+)'", sql, flags=re.I) + email_or_match = re.search( + r"or user_email\s*=\s*'([^']+)'", sql, flags=re.I + ) + user_sub = sub_match.group(1) if sub_match else None + user_email = email_or_match.group(1).lower() if email_or_match else None + + # Check if user is a member (by sub or email) + is_member = any( + m["project_id"] == pid + and (m.get("user_sub") == user_sub or m.get("user_email") == user_email) + for m in self.members.values() + ) + if is_member: + return {"result_type": "TuplesOk", "result": [["1"], [1]]} + return {"result_type": "TuplesOk", "result": [["1"]]} + + # Check for email filter (for _get_member_by_email) + email_match = re.search( + r"and user_email\s*=\s*'([^']+)'", sql, flags=re.I + ) + if email_match: + email = email_match.group(1).lower() + rows = [ + m + for m in self.members.values() + if m["project_id"] == pid and m["user_email"] == email + ] + else: + rows = [m for m in self.members.values() if m["project_id"] == pid] + + # SELECT 1 without complex filter (simple membership check) + if sql_l.startswith("select 1"): + if rows: + return {"result_type": "TuplesOk", "result": [["1"], [1]]} + return {"result_type": "TuplesOk", "result": [["1"]]} + + header = [ + "project_id", + "user_sub", + "user_email", + "added_at", + "added_by_sub", + ] + out = [header] + for m in rows: + out.append( + [ + m["project_id"], + m["user_sub"], + m["user_email"], + m["added_at"], + m["added_by_sub"], + ] + ) + return {"result_type": "TuplesOk", "result": out} + + # DELETE member + if sql_l.startswith("delete from amicable_meta.project_members"): + pid_match = re.search(r"where project_id\s*=\s*'([^']+)'", sql, flags=re.I) + sub_match = re.search(r"and user_sub\s*=\s*'([^']+)'", sql, flags=re.I) + email_match = re.search(r"and user_email\s*=\s*'([^']+)'", sql, flags=re.I) + if pid_match and sub_match: + key = (pid_match.group(1), sub_match.group(1)) + self.members.pop(key, None) + elif pid_match and email_match: + target_pid = pid_match.group(1) + target_email = email_match.group(1).lower() + to_remove = [ + k for k, v in self.members.items() + if v["project_id"] == target_pid and v["user_email"] == target_email + ] + for k in to_remove: + self.members.pop(k, None) + return {"result_type": "CommandOk", "result": []} + raise AssertionError(f"Unhandled SQL in test fake: {sql}") def metadata(self, payload: dict): @@ -287,3 +529,208 @@ def test_ensure_project_for_id_owner_mismatch() -> None: with pytest.raises(PermissionError): ensure_project_for_id(c, owner=owner2, project_id="abc-123") + + +def test_project_members_table_created() -> None: + """Verify project_members table is created during schema migration.""" + from src.projects import store + + # Reset schema state for this test + store._schema_ready = False + + c = FakeHasuraClient() + from src.projects.store import ensure_projects_schema + + ensure_projects_schema(c) + # The fake client should have received CREATE TABLE for project_members + assert c.schema_created_tables.get("project_members") is True + + +def test_add_and_list_project_members() -> None: + """Test adding and listing project members.""" + c = FakeHasuraClient() + owner = ProjectOwner(sub="u1", email="u1@example.com") + p = create_project(c, owner=owner, name="Shared Project") + + from src.projects.store import add_project_member, list_project_members + + # Creator should already be a member + members = list_project_members(c, project_id=p.project_id) + assert len(members) == 1 + assert members[0].user_sub == "u1" + + # Add another member + add_project_member( + c, project_id=p.project_id, user_email="u2@example.com", added_by_sub="u1" + ) + members = list_project_members(c, project_id=p.project_id) + assert len(members) == 2 + emails = {m.user_email for m in members} + assert emails == {"u1@example.com", "u2@example.com"} + + +def test_shared_project_access() -> None: + """Users can access projects they're members of, even if not the creator.""" + c = FakeHasuraClient() + owner = ProjectOwner(sub="u1", email="u1@example.com") + other = ProjectOwner(sub="u2", email="u2@example.com") + + p = create_project(c, owner=owner, name="Shared Project") + + # Other user cannot access yet + assert get_project_by_id(c, owner=other, project_id=p.project_id) is None + + # Add other as member + from src.projects.store import add_project_member + + add_project_member( + c, + project_id=p.project_id, + user_sub="u2", + user_email="u2@example.com", + added_by_sub="u1", + ) + + # Now other can access + got = get_project_by_id(c, owner=other, project_id=p.project_id) + assert got is not None + assert got.project_id == p.project_id + + # And it appears in their list + lst = list_projects(c, owner=other) + assert any(proj.project_id == p.project_id for proj in lst) + + +def test_project_locking() -> None: + """Test session locking prevents concurrent access.""" + c = FakeHasuraClient() + owner1 = ProjectOwner(sub="u1", email="u1@example.com") + + p = create_project(c, owner=owner1, name="Lockable") + + # Add u2 as member + from src.projects.store import ( + acquire_project_lock, + add_project_member, + get_project_lock, + release_project_lock, + ) + + add_project_member( + c, + project_id=p.project_id, + user_sub="u2", + user_email="u2@example.com", + added_by_sub="u1", + ) + + # u1 acquires lock + lock = acquire_project_lock( + c, project_id=p.project_id, user_sub="u1", user_email="u1@example.com" + ) + assert lock is not None + assert lock.locked_by_sub == "u1" + + # u2 cannot acquire (without force) + lock2 = acquire_project_lock( + c, project_id=p.project_id, user_sub="u2", user_email="u2@example.com" + ) + assert lock2 is None + + # Check who has lock + current = get_project_lock(c, project_id=p.project_id) + assert current is not None + assert current.locked_by_sub == "u1" + + # u1 releases + release_project_lock(c, project_id=p.project_id, user_sub="u1") + + # Now u2 can acquire + lock3 = acquire_project_lock( + c, project_id=p.project_id, user_sub="u2", user_email="u2@example.com" + ) + assert lock3 is not None + assert lock3.locked_by_sub == "u2" + + +def test_project_lock_force_claim() -> None: + """Test that force=True allows taking over a lock.""" + c = FakeHasuraClient() + owner1 = ProjectOwner(sub="u1", email="u1@example.com") + + p = create_project(c, owner=owner1, name="Forceable") + + from src.projects.store import ( + acquire_project_lock, + add_project_member, + get_project_lock, + ) + + add_project_member( + c, + project_id=p.project_id, + user_sub="u2", + user_email="u2@example.com", + added_by_sub="u1", + ) + + # u1 acquires lock + acquire_project_lock( + c, project_id=p.project_id, user_sub="u1", user_email="u1@example.com" + ) + + # u2 force-claims + lock = acquire_project_lock( + c, + project_id=p.project_id, + user_sub="u2", + user_email="u2@example.com", + force=True, + ) + assert lock is not None + assert lock.locked_by_sub == "u2" + + # Verify lock is now held by u2 + current = get_project_lock(c, project_id=p.project_id) + assert current is not None + assert current.locked_by_sub == "u2" + + +def test_member_can_delete_project() -> None: + """Any member can delete a project.""" + c = FakeHasuraClient() + owner = ProjectOwner(sub="u1", email="u1@example.com") + other = ProjectOwner(sub="u2", email="u2@example.com") + + from src.projects.store import add_project_member + + p = create_project(c, owner=owner, name="Deletable") + add_project_member( + c, + project_id=p.project_id, + user_sub="u2", + user_email="u2@example.com", + added_by_sub="u1", + ) + + # u2 (not creator) deletes + mark_project_deleted(c, owner=other, project_id=p.project_id) + + # Verify deleted + assert get_project_by_id(c, owner=owner, project_id=p.project_id) is None + + +def test_non_member_cannot_delete_project() -> None: + """Non-members cannot delete a project.""" + c = FakeHasuraClient() + owner = ProjectOwner(sub="u1", email="u1@example.com") + non_member = ProjectOwner(sub="u3", email="u3@example.com") + + p = create_project(c, owner=owner, name="Protected") + + # u3 (not a member) tries to delete + with pytest.raises(PermissionError): + mark_project_deleted(c, owner=non_member, project_id=p.project_id) + + # Verify not deleted + assert get_project_by_id(c, owner=owner, project_id=p.project_id) is not None diff --git a/tests/test_ws_member_api.py b/tests/test_ws_member_api.py new file mode 100644 index 0000000..3777eb2 --- /dev/null +++ b/tests/test_ws_member_api.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +pytest.importorskip("fastapi") +pytest.importorskip("dotenv") + + +def _patch_auth_rejected(): + """Patch _get_owner_from_request to always reject, regardless of AUTH_MODE.""" + return patch( + "src.runtimes.ws_server._get_owner_from_request", + side_effect=PermissionError("not authenticated"), + ) + + +def test_list_members_requires_auth(): + """GET /api/projects/{id}/members requires authentication.""" + from fastapi.testclient import TestClient + + from src.runtimes.ws_server import app + + with patch("src.runtimes.ws_server._require_hasura"), _patch_auth_rejected(): + client = TestClient(app) + resp = client.get("/api/projects/test-id/members") + assert resp.status_code == 401 + + +def test_add_member_requires_auth(): + """POST /api/projects/{id}/members requires authentication.""" + from fastapi.testclient import TestClient + + from src.runtimes.ws_server import app + + with patch("src.runtimes.ws_server._require_hasura"), _patch_auth_rejected(): + client = TestClient(app) + resp = client.post("/api/projects/test-id/members", json={"email": "test@example.com"}) + assert resp.status_code == 401 + + +def test_remove_member_requires_auth(): + """DELETE /api/projects/{id}/members/{sub} requires authentication.""" + from fastapi.testclient import TestClient + + from src.runtimes.ws_server import app + + with patch("src.runtimes.ws_server._require_hasura"), _patch_auth_rejected(): + client = TestClient(app) + resp = client.delete("/api/projects/test-id/members/some-sub") + assert resp.status_code == 401