refactor: Redesign the auth process

This commit is contained in:
Tuluobo 2024-09-15 19:04:00 +08:00
parent 19d063ae14
commit 1069de389c
16 changed files with 232 additions and 114 deletions

View File

@ -35,6 +35,7 @@ CREATE TABLE "codes" (
"id" TEXT NOT NULL, "id" TEXT NOT NULL,
"code" TEXT NOT NULL, "code" TEXT NOT NULL,
"expiresAt" TIMESTAMP(3) NOT NULL, "expiresAt" TIMESTAMP(3) NOT NULL,
"deletedAt" TIMESTAMP(3),
"userId" TEXT NOT NULL, "userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL, "clientId" TEXT NOT NULL,
@ -52,6 +53,18 @@ CREATE TABLE "access_tokens" (
CONSTRAINT "access_tokens_pkey" PRIMARY KEY ("id") CONSTRAINT "access_tokens_pkey" PRIMARY KEY ("id")
); );
-- CreateTable
CREATE TABLE "authorizations" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scope" TEXT,
CONSTRAINT "authorizations_pkey" PRIMARY KEY ("id")
);
-- CreateIndex -- CreateIndex
CREATE UNIQUE INDEX "users_username_key" ON "users"("username"); CREATE UNIQUE INDEX "users_username_key" ON "users"("username");
@ -67,6 +80,9 @@ CREATE UNIQUE INDEX "codes_code_key" ON "codes"("code");
-- CreateIndex -- CreateIndex
CREATE UNIQUE INDEX "access_tokens_token_key" ON "access_tokens"("token"); CREATE UNIQUE INDEX "access_tokens_token_key" ON "access_tokens"("token");
-- CreateIndex
CREATE UNIQUE INDEX "authorizations_userId_clientId_key" ON "authorizations"("userId", "clientId");
-- AddForeignKey -- AddForeignKey
ALTER TABLE "clients" ADD CONSTRAINT "clients_userId_fkey" FOREIGN KEY ("userId") REFERENCES "users"("id") ON DELETE RESTRICT ON UPDATE CASCADE; ALTER TABLE "clients" ADD CONSTRAINT "clients_userId_fkey" FOREIGN KEY ("userId") REFERENCES "users"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
@ -81,3 +97,9 @@ ALTER TABLE "access_tokens" ADD CONSTRAINT "access_tokens_userId_fkey" FOREIGN K
-- AddForeignKey -- AddForeignKey
ALTER TABLE "access_tokens" ADD CONSTRAINT "access_tokens_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "clients"("id") ON DELETE RESTRICT ON UPDATE CASCADE; ALTER TABLE "access_tokens" ADD CONSTRAINT "access_tokens_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "clients"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "authorizations" ADD CONSTRAINT "authorizations_userId_fkey" FOREIGN KEY ("userId") REFERENCES "users"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "authorizations" ADD CONSTRAINT "authorizations_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "clients"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

View File

@ -24,6 +24,8 @@ model User {
codes Code[] codes Code[]
accessTokens AccessToken[] accessTokens AccessToken[]
authorizations Authorization[]
createdAt DateTime @default(now()) @map(name: "created_at") createdAt DateTime @default(now()) @map(name: "created_at")
updatedAt DateTime @default(now()) @map(name: "updated_at") updatedAt DateTime @default(now()) @map(name: "updated_at")
@ -47,13 +49,16 @@ model Client {
authCodes Code[] authCodes Code[]
accessTokens AccessToken[] accessTokens AccessToken[]
authorizations Authorization[]
@@map("clients") @@map("clients")
} }
model Code { model Code {
id String @id @default(cuid()) id String @id @default(cuid())
code String @unique code String @unique
expiresAt DateTime expiresAt DateTime
deletedAt DateTime?
userId String userId String
user User @relation(fields: [userId], references: [id]) user User @relation(fields: [userId], references: [id])
@ -77,3 +82,20 @@ model AccessToken {
@@map("access_tokens") @@map("access_tokens")
} }
model Authorization {
id String @id @default(cuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String
user User @relation(fields: [userId], references: [id])
clientId String
client Client @relation(fields: [clientId], references: [id])
scope String?
@@unique([userId, clientId])
@@map("authorizations")
}

View File

@ -0,0 +1,23 @@
"use server";
import { createAuthorization } from "@/lib/dto/authorization";
import { getAuthorizeUrl } from "@/lib/oauth/authorize-url";
export async function handleAuthorizeAction(
oauth: string,
userId: string,
clientId: string,
scope: string,
) {
const oauthParams = new URLSearchParams(atob(oauth));
const redirectUrl = getAuthorizeUrl(oauthParams);
// 保存授权
await createAuthorization({
userId,
clientId,
scope,
});
return redirectUrl;
}

View File

@ -1,45 +0,0 @@
"use server";
import WordArray from "crypto-js/lib-typedarrays";
import { verify } from "@/lib/discourse-verify";
import { getClientByClientId } from "@/lib/dto/client";
import { createCode } from "@/lib/dto/code";
export async function handleDiscourseCallbackAction(searchParams: string) {
const params = new URLSearchParams(searchParams);
const sig = params.get("sig") as string;
const sso = params.get("sso") as string;
const oauth = params.get("oauth") as string;
const user = await verify(sso, sig);
// code redirect ...
const oauthParams = new URLSearchParams(atob(oauth));
const client = await getClientByClientId(
oauthParams.get("client_id") as string,
);
if (!client) {
throw new Error("Client Id invalid (code: -1004).");
}
const redirect_uri = new URL(client.redirectUri);
if (oauthParams.has("state")) {
redirect_uri.searchParams.append("state", oauthParams.get("state") || "");
}
const code = WordArray.random(32).toString();
redirect_uri.searchParams.append("code", code);
// storage
try {
await createCode({
code,
expiresAt: new Date(Date.now() + 10 * 60 * 1000),
clientId: client.id,
userId: user.id,
});
} catch {
throw new Error("Create code error (code: -1005).");
}
return redirect_uri.toString();
}

View File

@ -1,13 +1,51 @@
import { Suspense } from "react"; import { Suspense } from "react";
import { redirect } from "next/navigation";
import { Authorizing } from "@/components/auth/authorizing"; import { discourseCallbackVerify } from "@/lib/discourse/verify";
import { findAuthorization } from "@/lib/dto/authorization";
import { getClientByClientId } from "@/lib/dto/client";
import { getAuthorizeUrl } from "@/lib/oauth/authorize-url";
import { AuthorizationCard } from "@/components/auth/authorization-card";
export interface DiscourseCallbackParams extends Record<string, string> {
sig: string;
sso: string;
oauth: string;
}
export default async function DiscourseCallbackPage({
searchParams,
}: {
searchParams: DiscourseCallbackParams;
}) {
const oauthParams = new URLSearchParams(atob(searchParams.oauth));
// check client info
const client = await getClientByClientId(
oauthParams.get("client_id") as string,
);
if (!client) {
throw new Error("Client Id invalid (code: -1004).");
}
// verify discourse callback
const user = await discourseCallbackVerify(
searchParams.sso,
searchParams.sig,
);
// check authorization
const authorization = await findAuthorization(user.id, client.id);
if (authorization) {
const redirectUrl = await getAuthorizeUrl(oauthParams);
return redirect(redirectUrl);
}
export default function AuthPage() {
return ( return (
<main className="flex min-h-screen flex-col items-center justify-center"> <main className="flex min-h-screen flex-col items-center justify-center">
<div className="flex items-center justify-center"> <div className="flex items-center justify-center">
<Suspense> <Suspense>
<Authorizing /> <AuthorizationCard client={client} oauthParams={searchParams.oauth} />
</Suspense> </Suspense>
</div> </div>
</main> </main>

View File

@ -1,8 +1,5 @@
import { redirect } from "next/navigation";
import { getClientByClientId } from "@/lib/dto/client"; import { getClientByClientId } from "@/lib/dto/client";
import { prisma } from "@/lib/prisma"; import { Authorizing } from "@/components/auth/authorizing";
import { AuthorizationCard } from "@/components/auth/authorization-card";
export interface AuthorizeParams extends Record<string, string> { export interface AuthorizeParams extends Record<string, string> {
scope: string; scope: string;
@ -16,6 +13,7 @@ export default async function OAuthAuthorization({
}: { }: {
searchParams: AuthorizeParams; searchParams: AuthorizeParams;
}) { }) {
// params invalid
if ( if (
!searchParams.response_type || !searchParams.response_type ||
!searchParams.client_id || !searchParams.client_id ||
@ -24,34 +22,16 @@ export default async function OAuthAuthorization({
throw new Error("Params invalid"); throw new Error("Params invalid");
} }
const client = await getClient({ // client invalid
clientId: searchParams.client_id, const client = await getClientByClientId(searchParams.client_id);
redirectUri: searchParams.redirect_uri, if (!client || client.redirectUri !== searchParams.redirect_uri) {
});
if (!client) {
throw new Error("Client not found"); throw new Error("Client not found");
} }
// Authorizing ...
return ( return (
<div className="flex min-h-screen items-center justify-center bg-gray-50 p-4"> <div className="flex min-h-screen items-center justify-center bg-gray-50 p-4">
<AuthorizationCard client={client} /> <Authorizing />
</div> </div>
); );
} }
async function getClient({
clientId,
redirectUri,
}: {
clientId: string;
redirectUri: string;
}) {
const client = await getClientByClientId(clientId);
if (client && client.redirectUri === redirectUri) {
return client;
}
return null;
}

View File

@ -1,31 +1,31 @@
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { createAccessToken } from "@/lib/dto/accessToken"; import { createAccessToken } from "@/lib/dto/accessToken";
import { deleteCode, getCodeByCode } from "@/lib/dto/code"; import { deleteCode, getUnexpiredCodeByCode } from "@/lib/dto/code";
import { generateRandomKey } from "@/lib/utils"; import { generateRandomKey } from "@/lib/utils";
export async function POST(req: Request) { export async function POST(req: Request) {
const formData = await req.formData(); const formData = await req.formData();
// get code
const code = formData.get("code") as string; const code = formData.get("code") as string;
if (!code) { if (!code) {
console.log(`code: ${code}`); return new NextResponse("Invalid code params.", { status: 400 });
return new NextResponse("Invalid code credentials.", { status: 400 });
} }
const authorizeCode = await getCodeByCode(code); const authorizeCode = await getUnexpiredCodeByCode(code);
await deleteCode(code); await deleteCode(code);
if (!authorizeCode) { if (!authorizeCode) {
console.log(`code: ${code}`);
return new NextResponse("Invalid code credentials.", { status: 400 }); return new NextResponse("Invalid code credentials.", { status: 400 });
} }
// verify redirect uri
if (authorizeCode.client.redirectUri !== formData.get("redirect_uri")) { if (authorizeCode.client.redirectUri !== formData.get("redirect_uri")) {
console.log(
`redirectUri: ${authorizeCode.client.redirectUri} !== formData.get("redirect_uri"): ${formData.get("redirect_uri")}`,
);
return new NextResponse("Invalid redirect uri.", { status: 400 }); return new NextResponse("Invalid redirect uri.", { status: 400 });
} }
// generate access token
const expiresIn = 3600 * 24 * 7; const expiresIn = 3600 * 24 * 7;
const token = "tk_" + generateRandomKey(); const token = "at_" + generateRandomKey(32);
await createAccessToken({ await createAccessToken({
token, token,
expiresAt: new Date(Date.now() + expiresIn * 1000), expiresAt: new Date(Date.now() + expiresIn * 1000),

View File

@ -5,6 +5,7 @@ import { getAccessTokenByToken } from "@/lib/dto/accessToken";
export async function GET(req: Request) { export async function GET(req: Request) {
const authorization = req.headers.get("Authorization"); const authorization = req.headers.get("Authorization");
// verify access token
const token = authorization?.slice(7); // remove "Bearer " const token = authorization?.slice(7); // remove "Bearer "
if (!token) { if (!token) {
return new NextResponse("Invalid access token (code: -1000).", { return new NextResponse("Invalid access token (code: -1000).", {
@ -18,8 +19,8 @@ export async function GET(req: Request) {
}); });
} }
// return user
let user = accessToken.user; let user = accessToken.user;
return Response.json({ return Response.json({
id: user.id, id: user.id,
email: user.email, email: user.email,

View File

@ -1,7 +1,7 @@
import type { NextAuthConfig } from "next-auth"; import type { NextAuthConfig } from "next-auth";
import Credentials from "next-auth/providers/credentials"; import Credentials from "next-auth/providers/credentials";
import { verify } from "./lib/discourse-verify"; import { discourseCallbackVerify } from "./lib/discourse/verify";
// Notice this is only an object, not a full Auth.js instance // Notice this is only an object, not a full Auth.js instance
export default { export default {
@ -14,7 +14,7 @@ export default {
authorize: async (credentials) => { authorize: async (credentials) => {
const sso = credentials.sso as string; const sso = credentials.sso as string;
const sig = credentials.sig as string; const sig = credentials.sig as string;
const user = await verify(sso, sig); const user = await discourseCallbackVerify(sso, sig);
return user; return user;
}, },
}), }),

View File

@ -1,8 +1,8 @@
"use client"; "use client";
import { useState } from "react"; import { useState } from "react";
import { useRouter, useSearchParams } from "next/navigation"; import { useRouter } from "next/navigation";
import { getDiscourseSSOUrl } from "@/actions/discourse-sso-url"; import { handleAuthorizeAction } from "@/actions/authorizing";
import { Client } from "@prisma/client"; import { Client } from "@prisma/client";
import { import {
ChevronsDownUp, ChevronsDownUp,
@ -35,19 +35,29 @@ const permissions: Permission[] = [
}, },
]; ];
export function AuthorizationCard({ client }: { client: Client }) { export function AuthorizationCard({
client,
oauthParams,
}: {
client: Client;
oauthParams: string;
}) {
const [expandedPermission, setExpandedPermission] = useState<string | null>( const [expandedPermission, setExpandedPermission] = useState<string | null>(
null, null,
); );
const router = useRouter(); const router = useRouter();
const searchParams = useSearchParams();
const togglePermission = (id: string) => { const togglePermission = (id: string) => {
setExpandedPermission(expandedPermission === id ? null : id); setExpandedPermission(expandedPermission === id ? null : id);
}; };
const authorizingHandler = async () => { const authorizingHandler = async () => {
const url = await getDiscourseSSOUrl(searchParams.toString()); const url = await handleAuthorizeAction(
oauthParams,
client.userId,
client.id,
permissions[0].id,
);
router.push(url); router.push(url);
}; };

View File

@ -2,41 +2,36 @@
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { useRouter, useSearchParams } from "next/navigation"; import { useRouter, useSearchParams } from "next/navigation";
import { handleDiscourseCallbackAction } from "@/actions/discourse-callback"; import { getDiscourseSSOUrl } from "@/actions/discourse-sso-url";
export function Authorizing() { export function Authorizing() {
const router = useRouter(); const router = useRouter();
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const [isLoading, setIsLoading] = useState<boolean>(false);
const [error, setError] = useState<unknown | null>(null); const [error, setError] = useState<unknown | null>(null);
const signInCallback = useCallback(async () => { const signInCallback = useCallback(async () => {
if (isLoading) {
return;
}
setIsLoading(true);
try { try {
const url = await handleDiscourseCallbackAction(searchParams.toString()); const url = await getDiscourseSSOUrl(searchParams.toString());
router.push(url); router.push(url);
setIsLoading(false);
} catch (error) { } catch (error) {
setError(error); setError(error);
setIsLoading(false);
} }
}, []); }, []);
useEffect(() => { useEffect(() => {
// Delay 3s get sso url go to ...
const timer = setTimeout(signInCallback, 3); const timer = setTimeout(signInCallback, 3);
return () => { return () => {
clearTimeout(timer); clearTimeout(timer);
}; };
}, []); }, []);
return ( return (
<> <>
{error ? ( {error ? (
<p className="text-center"></p> <p className="text-center"></p>
) : ( ) : (
<p className="text-center"> ...</p> <p className="text-center"> ...</p>
)} )}
</> </>
); );

View File

@ -10,7 +10,7 @@ import { createUser, getUserById, updateUser } from "@/lib/dto/user";
const DISCOUSE_SECRET = process.env.DISCOUSE_SECRET as string; const DISCOUSE_SECRET = process.env.DISCOUSE_SECRET as string;
export async function verify(sso: string, sig: string) { export async function discourseCallbackVerify(sso: string, sig: string) {
// 校验数据正确性 // 校验数据正确性
if (hmacSHA256(sso, DISCOUSE_SECRET).toString(Hex) != sig) { if (hmacSHA256(sso, DISCOUSE_SECRET).toString(Hex) != sig) {
throw new Error("Request params is invalid (code: -1001)."); throw new Error("Request params is invalid (code: -1001).");
@ -25,7 +25,7 @@ export async function verify(sso: string, sig: string) {
if (cookieStore.get(AUTH_NONCE)?.value != nonce) { if (cookieStore.get(AUTH_NONCE)?.value != nonce) {
throw new Error("Request params is invalid (code: -1003)."); throw new Error("Request params is invalid (code: -1003).");
} }
cookieStore.delete(AUTH_NONCE); // cookieStore.delete(AUTH_NONCE);
const id = searchParams.get("external_id"); const id = searchParams.get("external_id");
const email = searchParams.get("email"); const email = searchParams.get("email");

View File

@ -9,8 +9,14 @@ export async function createAccessToken(
} }
export async function getAccessTokenByToken(token: string) { export async function getAccessTokenByToken(token: string) {
return prisma.accessToken.findUnique({ const now = new Date();
where: { token }, return prisma.accessToken.findFirst({
where: {
token,
expiresAt: {
gt: now,
},
},
include: { user: true }, include: { user: true },
}); });
} }

View File

@ -0,0 +1,22 @@
import { Authorization } from "@prisma/client";
import { prisma } from "../prisma";
export async function createAuthorization(
data: Omit<Authorization, "id" | "createdAt" | "updatedAt">,
) {
await prisma.authorization.create({
data,
});
}
export async function findAuthorization(userId: string, clientId: string) {
return await prisma.authorization.findUnique({
where: {
userId_clientId: {
userId,
clientId,
},
},
});
}

View File

@ -2,17 +2,25 @@ import { Code } from "@prisma/client";
import { prisma } from "@/lib/prisma"; import { prisma } from "@/lib/prisma";
export async function createCode(data: Omit<Code, "id">) { export async function createCode(data: Omit<Code, "id" | "deletedAt">) {
return prisma.code.create({ data }); return prisma.code.create({ data });
} }
export async function getCodeByCode(code: string) { export async function getUnexpiredCodeByCode(code: string) {
return prisma.code.findUnique({ const now = new Date();
where: { code }, return prisma.code.findFirst({
where: {
code,
expiresAt: { gt: now },
deletedAt: null,
},
include: { client: true, user: true }, include: { client: true, user: true },
}); });
} }
export async function deleteCode(code: string) { export async function deleteCode(code: string) {
await prisma.code.delete({ where: { code } }); await prisma.code.update({
where: { code },
data: { deletedAt: new Date() },
});
} }

View File

@ -0,0 +1,36 @@
import "server-only";
import WordArray from "crypto-js/lib-typedarrays";
import { getClientByClientId } from "@/lib/dto/client";
import { createCode } from "@/lib/dto/code";
export async function getAuthorizeUrl(params: URLSearchParams) {
// client
const client = await getClientByClientId(params.get("client_id") as string);
if (!client) {
throw new Error("Client Id invalid (code: -1004).");
}
// redirect url
const redirect_uri = new URL(client.redirectUri);
if (params.has("state")) {
redirect_uri.searchParams.append("state", params.get("state") || "");
}
const code = WordArray.random(32).toString();
redirect_uri.searchParams.append("code", code);
// storage code
try {
await createCode({
code,
expiresAt: new Date(Date.now() + 10 * 60 * 1000),
clientId: client.id,
userId: client.userId,
});
} catch {
throw new Error("Create code error (code: -1005).");
}
return redirect_uri.toString();
}