Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions api/controllers/openapi/oauth_device_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from __future__ import annotations

import logging
import re
import secrets
from dataclasses import dataclass
from urllib.parse import urlencode

from flask import jsonify, make_response, redirect, request
from pydantic import ValidationError
Expand Down Expand Up @@ -74,6 +76,21 @@
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"

_ALLOWED_SSO_ERRORS = {"sso_failed", "email_belongs_to_dify_account"}

# user_code only ever reaches the redirect as a urlencoded query value; the
# charset bound additionally forbids the path/scheme separators a redirection
# attack would need, so an untrusted value cannot escape the fixed /device path.
_USER_CODE_RE = re.compile(r"\A[A-Z0-9-]{1,16}\Z")


def _device_error_redirect(code: str, user_code: str | None = None):
safe_code = code if code in _ALLOWED_SSO_ERRORS else "sso_failed"
params: dict[str, str] = {"sso_error": safe_code}
if user_code and _USER_CODE_RE.match(user_code):
params["user_code"] = user_code
return redirect(f"/device?{urlencode(params)}", code=302)
Comment thread
GareArc marked this conversation as resolved.
Dismissed


def _trusted_origin() -> str:
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
Expand Down Expand Up @@ -134,43 +151,56 @@ def sso_initiate():
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
try:
return _sso_complete_impl()
except Exception:
logger.exception("sso-complete: unhandled")
return _device_error_redirect("sso_failed")


def _sso_complete_impl():
inbound_error = request.args.get("sso_error")
if inbound_error:
return _device_error_redirect(inbound_error, request.args.get("user_code"))

blob = request.args.get("sso_assertion")
if not blob:
raise BadRequest("sso_assertion required")
return _device_error_redirect("sso_failed")

keyset = jws.KeySet.from_shared_secret()

try:
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
return _device_error_redirect("sso_failed")

try:
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
except ValidationError as e:
logger.warning("sso-complete: claim shape invalid: %s", e)
raise BadRequest("invalid_sso_assertion") from e
return _device_error_redirect("sso_failed")

user_code = claims.user_code.strip().upper()

if not consume_sso_assertion_nonce(redis_client, claims.nonce):
raise BadRequest("invalid_sso_assertion")
return _device_error_redirect("sso_failed", user_code)

user_code = claims.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise Conflict("user_code_not_pending")
return _device_error_redirect("sso_failed", user_code)
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
return _device_error_redirect("sso_failed", user_code)

if AccountService.has_active_account_with_email(db.session, claims.email):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
reason="email_belongs_to_dify_account",
)
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
return _device_error_redirect("email_belongs_to_dify_account", user_code)

iss = _trusted_origin()
cookie_value, _ = mint_approval_grant(
Expand Down
95 changes: 95 additions & 0 deletions api/tests/unit_tests/controllers/openapi/test_device_sso.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""

import builtins
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
Expand Down Expand Up @@ -77,3 +78,97 @@ def test_sso_complete_idp_callback_url_uses_canonical_path():
from controllers.openapi import oauth_device_sso

assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"


# ---------------------------------------------------------------------------
# _device_error_redirect helper
# ---------------------------------------------------------------------------


def test_device_error_redirect_builds_relative_location():
from controllers.openapi import oauth_device_sso

app = Flask(__name__)
with app.test_request_context():
resp = oauth_device_sso._device_error_redirect("sso_failed", "ABCD-1234")
assert resp.status_code == 302
loc = resp.headers["Location"]
assert loc.startswith("/device?")
assert "sso_error=sso_failed" in loc
assert "user_code=ABCD-1234" in loc


def test_device_error_redirect_clamps_unknown_code():
from controllers.openapi import oauth_device_sso

app = Flask(__name__)
with app.test_request_context():
resp = oauth_device_sso._device_error_redirect("totally-bogus")
assert "sso_error=sso_failed" in resp.headers["Location"]


def test_device_error_redirect_keeps_email_special_case():
from controllers.openapi import oauth_device_sso

app = Flask(__name__)
with app.test_request_context():
resp = oauth_device_sso._device_error_redirect("email_belongs_to_dify_account", "ABCD-1234")
assert "sso_error=email_belongs_to_dify_account" in resp.headers["Location"]


def test_device_error_redirect_omits_empty_user_code():
from controllers.openapi import oauth_device_sso

app = Flask(__name__)
with app.test_request_context():
resp = oauth_device_sso._device_error_redirect("sso_failed")
assert "user_code=" not in resp.headers["Location"]


def test_device_error_redirect_drops_malformed_user_code():
from controllers.openapi import oauth_device_sso

app = Flask(__name__)
with app.test_request_context():
resp = oauth_device_sso._device_error_redirect("sso_failed", "https://evil.example/")
loc = resp.headers["Location"]
assert loc.startswith("/device?")
assert "user_code=" not in loc
assert "evil" not in loc


# ---------------------------------------------------------------------------
# sso_complete redirect behaviour
# ---------------------------------------------------------------------------


def _ee_features():
from services.feature_service import LicenseStatus

m = MagicMock()
m.license.status = LicenseStatus.ACTIVE
return m


@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_relays_inbound_sso_error(ee_feat, openapi_app):
ee_feat.return_value = _ee_features()
client = openapi_app.test_client()
resp = client.get(
"/openapi/v1/oauth/device/sso-complete?sso_error=sso_failed&user_code=ABCD-1234",
follow_redirects=False,
)
assert resp.status_code == 302
loc = resp.headers["Location"]
assert "/device?" in loc
assert "sso_error=sso_failed" in loc
assert "user_code=ABCD-1234" in loc


@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_missing_assertion_redirects_generic(ee_feat, openapi_app):
ee_feat.return_value = _ee_features()
client = openapi_app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete", follow_redirects=False)
assert resp.status_code == 302
assert "sso_error=sso_failed" in resp.headers["Location"]
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Fla
jws_mod.VerifyError = Exception

client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400, resp.data
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob", follow_redirects=False)
assert resp.status_code == 302, resp.data
assert "sso_error=sso_failed" in resp.headers["Location"]


@patch("controllers.openapi.oauth_device_sso.jws")
Expand All @@ -48,8 +49,9 @@ def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flas
jws_mod.VerifyError = Exception

client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob", follow_redirects=False)
assert resp.status_code == 302
assert "sso_error=sso_failed" in resp.headers["Location"]


def test_verify_approval_grant_raises_on_missing_field():
Expand Down
53 changes: 33 additions & 20 deletions web/app/device/__tests__/page-terminal.spec.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -121,40 +121,53 @@ describe('error_lookup_failed terminal state', () => {
})
})

describe('sso_error inline banner on the code-entry page', () => {
const SSO_BANNER_COPY = 'deviceFlow.ssoError.emailBelongsToDifyAccount'

it('shows the error banner with friendly copy when sso_error is present', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
describe('error_sso dedicated view', () => {
const TITLE = 'deviceFlow.errorSso.title'
const GENERIC = 'deviceFlow.ssoError.default'
const EMAIL_COPY = 'deviceFlow.ssoError.emailBelongsToDifyAccount'
const BACK_TO_LOGIN = 'deviceFlow.errorSso.backToLoginOptions'

it('renders the dedicated SSO error screen (not the code-entry page)', async () => {
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
render(<DevicePage />)
expect(await screen.findByText(SSO_BANNER_COPY)).toBeInTheDocument()
expect(await screen.findByText(TITLE)).toBeInTheDocument()
expect(await screen.findByText(GENERIC)).toBeInTheDocument()
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})

it('keeps the code-entry screen visible (error on main page, not a separate view)', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
it('shows the email special-case copy', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account', user_code: 'ABCD-3456' }
render(<DevicePage />)
await screen.findByText(SSO_BANNER_COPY)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: /deviceFlow.codeEntry.continue/i })).toBeInTheDocument()
expect(await screen.findByText(EMAIL_COPY)).toBeInTheDocument()
})

it('does not surface the raw backend error code', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
it('never surfaces the raw backend code', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account', user_code: 'ABCD-3456' }
render(<DevicePage />)
await screen.findByText(SSO_BANNER_COPY)
await screen.findByText(EMAIL_COPY)
expect(screen.queryByText('email_belongs_to_dify_account')).not.toBeInTheDocument()
})

it('does not scrub the param on mount (regression: error was wiped by router.replace)', async () => {
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
it('scrubs sso_error + user_code from the URL on mount', async () => {
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
render(<DevicePage />)
await screen.findByText(TITLE)
expect(mockReplace).toHaveBeenCalledWith('/device')
})

it('"Back to login options" re-checks the code and advances to the chooser', async () => {
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
mockDeviceLookup.mockResolvedValue({ valid: true })
render(<DevicePage />)
await screen.findByText(SSO_BANNER_COPY)
expect(mockReplace).not.toHaveBeenCalled()
await screen.findByText(TITLE)
fireEvent.click(screen.getByRole('button', { name: BACK_TO_LOGIN }))
await screen.findByText('chooser.subtitle')
expect(mockDeviceLookup).toHaveBeenCalledWith('ABCD-3456')
})

it('shows no banner when sso_error is absent', () => {
it('shows no SSO error screen when sso_error is absent', () => {
render(<DevicePage />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.queryByText(SSO_BANNER_COPY)).not.toBeInTheDocument()
expect(screen.queryByText(TITLE)).not.toBeInTheDocument()
})
})
55 changes: 42 additions & 13 deletions web/app/device/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type View
| { kind: 'error_expired' }
| { kind: 'error_rate_limited' }
| { kind: 'error_lookup_failed' }
| { kind: 'error_sso', code: string, userCode: string }

export default function DevicePage() {
const { t } = useTranslation('deviceFlow')
Expand Down Expand Up @@ -74,6 +75,11 @@ export default function DevicePage() {
useEffect(() => {
if (view.kind !== 'code_entry' && view.kind !== 'chooser')
return
if (ssoError) {
setView({ kind: 'error_sso', code: ssoError, userCode: urlUserCode }) // eslint-disable-line react/set-state-in-effect
router.replace(pathname)
return
}
// Post-login bounce: chooser holds the typed code, account just loaded.
// The URL was already scrubbed on the first effect run, so urlUserCode
// is empty here — advance using the userCode stashed in view state.
Expand All @@ -95,13 +101,11 @@ export default function DevicePage() {
}
if (consumed && (urlUserCode || ssoVerified))
router.replace(pathname)
}, [urlUserCode, ssoVerified, account, view, router, pathname])
}, [urlUserCode, ssoVerified, ssoError, account, view, router, pathname])

const onContinue = async () => {
if (!isValidUserCode(typed))
return
const advanceFromCode = async (code: string) => {
try {
const reply = await deviceLookup(typed)
const reply = await deviceLookup(code)
if (!reply.valid) {
setView({ kind: 'error_expired' })
return
Expand All @@ -118,20 +122,20 @@ export default function DevicePage() {
return
}
if (account)
setView({ kind: 'authorize_account', userCode: typed })
else setView({ kind: 'chooser', userCode: typed })
setView({ kind: 'authorize_account', userCode: code })
else setView({ kind: 'chooser', userCode: code })
}

const onContinue = async () => {
if (!isValidUserCode(typed))
return
await advanceFromCode(typed)
}

return (
<>
{view.kind === 'code_entry' && (
<div className="flex flex-col gap-5">
{ssoError && (
<div className="flex items-start gap-2 rounded-lg bg-state-destructive-hover p-3">
<span className="mt-0.5 i-ri-close-circle-line h-4 w-4 shrink-0 text-util-colors-red-red-600" />
<p className="text-sm text-text-destructive">{ssoErrorCopy(ssoError, t)}</p>
</div>
)}
<div>
<h1 className="text-2xl font-semibold text-text-primary">{t('codeEntry.title')}</h1>
<p className="mt-2 text-sm text-text-secondary">
Expand Down Expand Up @@ -273,6 +277,31 @@ export default function DevicePage() {
</div>
)}

{view.kind === 'error_sso' && (
<div className="flex flex-col gap-1">
<div className="mb-2.5 flex h-[38px] w-[38px] items-center justify-center rounded-full bg-state-warning-hover">
<span aria-hidden="true" className="i-ri-error-warning-line h-[18px] w-[18px] text-util-colors-yellow-yellow-600" />
</div>
<h1 className="text-xl font-semibold text-text-primary">{t('errorSso.title')}</h1>
<p className="text-sm text-text-secondary">{ssoErrorCopy(view.code, t)}</p>
<Divider className="my-3" />
<Button
variant="primary"
size="large"
className="w-full"
onClick={() => {
setErrMsg(null)
if (view.userCode)
advanceFromCode(view.userCode)
else
setView({ kind: 'code_entry' })
}}
>
{t('errorSso.backToLoginOptions')}
</Button>
</div>
)}

{errMsg && (
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
)}
Expand Down
Loading
Loading