From 5dcbb05b91b437b8d30e0f4b267765dbb67f5771 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Sat, 2 May 2026 00:02:13 +0530 Subject: [PATCH] implement OAuth 2.0 for mcp --- src/api/routes/auth.py | 96 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/api/routes/auth.py b/src/api/routes/auth.py index 6851db6..a4a70b1 100644 --- a/src/api/routes/auth.py +++ b/src/api/routes/auth.py @@ -76,6 +76,36 @@ def _get_and_invalidate_mcp_token(token: str) -> Optional[str]: return user_id +# ═══════════════════════════════════════════════════════════════════════════ +# Standard OAuth 2.0 Store (for ChatGPT UI) +# ═══════════════════════════════════════════════════════════════════════════ +_oauth_auth_codes: Dict[str, Dict[str, Any]] = {} + +def _generate_auth_code(user_id: str) -> str: + """Generate a standard OAuth 2.0 authorization code.""" + alphabet = string.ascii_letters + string.digits + code = "".join(secrets.choice(alphabet) for _ in range(32)) + + _oauth_auth_codes[code] = { + "user_id": user_id, + "expires_at": datetime.utcnow() + timedelta(minutes=10) + } + return code + +def _get_and_invalidate_auth_code(code: str) -> Optional[str]: + """Validate auth code and return user_id if valid.""" + if code not in _oauth_auth_codes: + return None + + data = _oauth_auth_codes[code] + del _oauth_auth_codes[code] # Single-use + + if datetime.utcnow() > data["expires_at"]: + return None + + return data["user_id"] + + # ═══════════════════════════════════════════════════════════════════════════ # Pydantic Models # ═══════════════════════════════════════════════════════════════════════════ @@ -142,6 +172,16 @@ class MCPExchangeResponse(BaseModel): user: dict +class OAuthApproveRequest(BaseModel): + """Request from frontend to approve OAuth and get a code.""" + client_id: str + redirect_uri: str + +class OAuthApproveResponse(BaseModel): + """Response with the authorization code.""" + code: str + + # ═══════════════════════════════════════════════════════════════════════════ # JWT Utilities # ═══════════════════════════════════════════════════════════════════════════ @@ -475,3 +515,59 @@ async def exchange_mcp_token(request: MCPExchangeRequest): api_key=key_result["key"], user=user_response ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Standard OAuth 2.0 Routes (For ChatGPT UI) +# ═══════════════════════════════════════════════════════════════════════════ + +@router.post("/oauth/approve", response_model=OAuthApproveResponse) +async def oauth_approve(request: OAuthApproveRequest, current_user: dict = Depends(require_user)): + """ + Called by the Next.js frontend when the user clicks 'Approve' on the consent screen. + Generates an authorization code for standard OAuth 2.0 flow. + """ + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required") + + user_id = str(current_user.get("id")) + code = _generate_auth_code(user_id) + return OAuthApproveResponse(code=code) + + +from fastapi import Form +from fastapi.responses import JSONResponse + +@router.post("/oauth/token") +async def oauth_token( + grant_type: str = Form(...), + code: str = Form(None), + redirect_uri: str = Form(None), + client_id: str = Form(None) +): + """ + Standard OAuth 2.0 token endpoint. + ChatGPT calls this directly to exchange the authorization code for an access token. + """ + if grant_type != "authorization_code": + return JSONResponse(status_code=400, content={"error": "unsupported_grant_type"}) + + if not code: + return JSONResponse(status_code=400, content={"error": "invalid_request", "error_description": "code is required"}) + + user_id = _get_and_invalidate_auth_code(code) + if not user_id: + return JSONResponse(status_code=400, content={"error": "invalid_grant", "error_description": "Invalid or expired authorization code"}) + + # Generate a permanent API key acting as the access token + key_result = api_key_store.create_api_key( + user_id=user_id, + name=f"OAuth Client ({client_id or 'Unknown'}) - {datetime.utcnow().strftime('%Y-%m-%d')}" + ) + + return { + "access_token": key_result["key"], + "token_type": "Bearer", + "expires_in": 31536000, # 1 year + "scope": "all" + }