feat: Add user access control and client management enhancements

- Introduced `allowedUsers` field to Client model for granular access control
- Implemented user filtering in authorization process
- Updated client edit form with allowed users configuration
- Enhanced dashboard and admin pages with improved user and client management
- Refactored client update and delete API routes
- Added form validation using Zod and react-hook-form
This commit is contained in:
wood chen 2025-02-20 01:49:52 +08:00
parent a81fee3f9a
commit 70e66294e3
15 changed files with 587 additions and 214 deletions

View File

@ -17,14 +17,15 @@
},
"dependencies": {
"@auth/prisma-adapter": "^2.4.2",
"@hookform/resolvers": "^4.1.0",
"@prisma/client": "^5.19.0",
"@radix-ui/react-alert-dialog": "^1.1.6",
"@radix-ui/react-avatar": "^1.1.0",
"@radix-ui/react-checkbox": "^1.1.1",
"@radix-ui/react-dialog": "^1.1.1",
"@radix-ui/react-dropdown-menu": "^2.1.1",
"@radix-ui/react-label": "^2.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-label": "^2.1.2",
"@radix-ui/react-slot": "^1.1.2",
"@radix-ui/react-switch": "^1.1.0",
"@radix-ui/react-toast": "^1.2.1",
"class-variance-authority": "^0.7.0",
@ -37,8 +38,10 @@
"next-themes": "^0.3.0",
"react": "^18",
"react-dom": "^18",
"react-hook-form": "^7.54.2",
"tailwind-merge": "^2.5.2",
"tailwindcss-animate": "^1.0.7"
"tailwindcss-animate": "^1.0.7",
"zod": "^3.24.2"
},
"devDependencies": {
"@ianvs/prettier-plugin-sort-imports": "^4.3.1",

35
pnpm-lock.yaml generated
View File

@ -8,6 +8,9 @@ dependencies:
'@auth/prisma-adapter':
specifier: ^2.4.2
version: 2.7.4(@prisma/client@5.22.0)
'@hookform/resolvers':
specifier: ^4.1.0
version: 4.1.0(react-hook-form@7.54.2)
'@prisma/client':
specifier: ^5.19.0
version: 5.22.0(prisma@5.22.0)
@ -27,10 +30,10 @@ dependencies:
specifier: ^2.1.1
version: 2.1.6(@types/react-dom@18.3.5)(@types/react@18.3.18)(react-dom@18.3.1)(react@18.3.1)
'@radix-ui/react-label':
specifier: ^2.1.0
specifier: ^2.1.2
version: 2.1.2(@types/react-dom@18.3.5)(@types/react@18.3.18)(react-dom@18.3.1)(react@18.3.1)
'@radix-ui/react-slot':
specifier: ^1.1.0
specifier: ^1.1.2
version: 1.1.2(@types/react@18.3.18)(react@18.3.1)
'@radix-ui/react-switch':
specifier: ^1.1.0
@ -68,12 +71,18 @@ dependencies:
react-dom:
specifier: ^18
version: 18.3.1(react@18.3.1)
react-hook-form:
specifier: ^7.54.2
version: 7.54.2(react@18.3.1)
tailwind-merge:
specifier: ^2.5.2
version: 2.6.0
tailwindcss-animate:
specifier: ^1.0.7
version: 1.0.7(tailwindcss@3.4.17)
zod:
specifier: ^3.24.2
version: 3.24.2
devDependencies:
'@ianvs/prettier-plugin-sort-imports':
@ -320,6 +329,15 @@ packages:
resolution: {integrity: sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==}
dev: false
/@hookform/resolvers@4.1.0(react-hook-form@7.54.2):
resolution: {integrity: sha512-fX/uHKb+OOCpACLc6enuTQsf0ZpRrKbeBBPETg5PCPLCIYV6osP2Bw6ezuclM61lH+wBF9eXcuC0+BFh9XOEnQ==}
peerDependencies:
react-hook-form: ^7.0.0
dependencies:
caniuse-lite: 1.0.30001700
react-hook-form: 7.54.2(react@18.3.1)
dev: false
/@humanwhocodes/config-array@0.13.0:
resolution: {integrity: sha512-DZLEEqFWQFiyK6h5YIeynKx7JlvCYWL0cImfSRXZ9l4Sg2efkFGTuFf6vzXjK1cq6IYkU+Eg/JizXw+TD2vRNw==}
engines: {node: '>=10.10.0'}
@ -3645,6 +3663,15 @@ packages:
scheduler: 0.23.2
dev: false
/react-hook-form@7.54.2(react@18.3.1):
resolution: {integrity: sha512-eHpAUgUjWbZocoQYUHposymRb4ZP6d0uwUnooL2uOybA9/3tPUvoAKqEWK1WaSiTxxOfTpffNZP7QwlnM3/gEg==}
engines: {node: '>=18.0.0'}
peerDependencies:
react: ^16.8.0 || ^17 || ^18 || ^19
dependencies:
react: 18.3.1
dev: false
/react-is@16.13.1:
resolution: {integrity: sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==}
dev: true
@ -4449,3 +4476,7 @@ packages:
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
engines: {node: '>=10'}
dev: true
/zod@3.24.2:
resolution: {integrity: sha512-lY7CDW43ECgW9u1TcT3IoXHflywfVqDYze4waEz812jR/bZ8FHDsl7pFQoSZTz5N+2NqRXs8GBwnAwo3ZNxqhQ==}
dev: false

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "clients" ADD COLUMN "allowedUsers" TEXT[] DEFAULT ARRAY[]::TEXT[];

View File

@ -44,6 +44,7 @@ model Client {
logo String
description String?
enabled Boolean @default(true)
allowedUsers String[] @default([])
clientId String @unique
clientSecret String

View File

@ -2,6 +2,7 @@
import { createAuthorization } from "@/lib/dto/authorization";
import { getAuthorizeUrl } from "@/lib/oauth/authorize-url";
import { prisma } from "@/lib/prisma";
export async function handleAuthorizeAction(
oauth: string,
@ -9,6 +10,28 @@ export async function handleAuthorizeAction(
clientId: string,
scope: string,
) {
// 检查客户端是否限制了允许的用户
const client = await prisma.client.findUnique({
where: { id: clientId },
select: { allowedUsers: true },
});
if (!client) {
throw new Error("应用不存在");
}
// 如果设置了允许的用户列表,检查当前用户是否在列表中
if (client.allowedUsers.length > 0) {
const user = await prisma.user.findUnique({
where: { id: userId },
select: { username: true },
});
if (!user || !client.allowedUsers.includes(user.username)) {
throw new Error("您没有权限使用此应用");
}
}
const oauthParams = new URLSearchParams(atob(oauth));
const redirectUrl = getAuthorizeUrl(oauthParams);

View File

@ -1,3 +1,4 @@
import Image from "next/image";
import Link from "next/link";
import { redirect } from "next/navigation";
import { Search } from "lucide-react";
@ -22,6 +23,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { ClientStatusToggle } from "@/components/admin/client-status-toggle";
async function getClients(search?: string) {
const where = search
@ -107,10 +109,13 @@ export default async function ClientsPage({
<TableCell className="font-medium">
<div className="flex items-center gap-2">
{client.logo && (
<img
<Image
src={client.logo}
alt={client.name}
className="h-6 w-6 rounded-full"
width={24}
height={24}
className="rounded-full"
unoptimized
/>
)}
{client.name}
@ -130,11 +135,14 @@ export default async function ClientsPage({
</Badge>
</TableCell>
<TableCell>
<div className="flex items-center justify-end gap-2">
<ClientStatusToggle client={client} />
<Link href={`/admin/clients/${client.id}`}>
<Button variant="outline" size="sm">
</Button>
</Link>
</div>
</TableCell>
</TableRow>
))}

View File

@ -20,27 +20,6 @@ import {
TableRow,
} from "@/components/ui/table";
async function getUsers(search?: string) {
const where = search
? {
OR: [
{ username: { contains: search } },
{ email: { contains: search } },
{ name: { contains: search } },
],
}
: {};
const users = await prisma.user.findMany({
where,
orderBy: {
createdAt: "desc",
},
});
return users;
}
export default async function UsersPage({
searchParams,
}: {
@ -51,7 +30,13 @@ export default async function UsersPage({
redirect("/dashboard");
}
const users = await getUsers(searchParams.search);
const search = searchParams.search || "";
const users = await prisma.user.findMany({
where: {
OR: [{ name: { contains: search } }, { email: { contains: search } }],
},
orderBy: { createdAt: "desc" },
});
return (
<div className="mx-auto max-w-7xl px-4 py-8 sm:px-6 lg:px-8">
@ -59,8 +44,8 @@ export default async function UsersPage({
<CardHeader>
<div className="flex items-center justify-between">
<div>
<CardTitle></CardTitle>
<CardDescription></CardDescription>
<CardTitle></CardTitle>
<CardDescription></CardDescription>
</div>
<div className="relative w-64">
<Search className="absolute left-2 top-2.5 h-4 w-4 text-muted-foreground" />
@ -79,9 +64,9 @@ export default async function UsersPage({
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
@ -90,15 +75,11 @@ export default async function UsersPage({
<TableBody>
{users.map((user) => (
<TableRow key={user.id}>
<TableCell className="font-medium">{user.username}</TableCell>
<TableCell>{user.id}</TableCell>
<TableCell>{user.name}</TableCell>
<TableCell>{user.email}</TableCell>
<TableCell>{user.name || "-"}</TableCell>
<TableCell>
{user.role === "ADMIN"
? "管理员"
: user.moderator
? "版主"
: "用户"}
{user.role === "ADMIN" ? "管理员" : "用户"}
</TableCell>
<TableCell>
{new Date(user.createdAt).toLocaleString()}

View File

@ -1,6 +1,11 @@
import Link from "next/link";
import { AppWindow, Settings } from "lucide-react";
import { redirect } from "next/navigation";
import type { Client } from "@prisma/client";
import { AppWindow, Users } from "lucide-react";
import { getAuthorizationStats } from "@/lib/dto/authorization";
import { prisma } from "@/lib/prisma";
import { getCurrentUser } from "@/lib/session";
import { Button } from "@/components/ui/button";
import {
Card,
@ -10,45 +15,120 @@ import {
CardTitle,
} from "@/components/ui/card";
export default function DashboardPage() {
interface ClientWithStats extends Client {
stats: {
total: number;
activeLastMonth: number;
newLastMonth: number;
};
}
async function getUserClients(userId: string): Promise<ClientWithStats[]> {
try {
const clients = await prisma.client.findMany({
where: { userId },
orderBy: { createdAt: "desc" },
});
// 获取每个应用的授权统计
const clientsWithStats = await Promise.all(
clients.map(async (client) => {
try {
const stats = await getAuthorizationStats(client.id);
return {
...client,
stats,
};
} catch (error) {
console.error(`获取应用 ${client.name} 的统计信息失败:`, error);
return {
...client,
stats: {
total: 0,
activeLastMonth: 0,
newLastMonth: 0,
},
};
}
}),
);
return clientsWithStats;
} catch (error) {
console.error("获取用户应用列表失败:", error);
return [];
}
}
export default async function DashboardPage() {
const user = await getCurrentUser();
if (!user?.id) {
redirect("/sign-in");
}
const clients = await getUserClients(user.id);
return (
<div className="mx-auto max-w-7xl px-4 py-8 sm:px-6 lg:px-8">
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-3">
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<AppWindow className="h-6 w-6" />
</CardTitle>
<CardDescription>
OAuth
</CardDescription>
</CardHeader>
<CardContent>
<Link href="/dashboard/clients">
<Button className="w-full"></Button>
<div className="mb-8 flex items-center justify-between">
<h2 className="text-lg font-medium"></h2>
<Link href="/dashboard/clients/new">
<Button></Button>
</Link>
</CardContent>
</Card>
</div>
<Card>
<div className="grid gap-6 sm:grid-cols-2 lg:grid-cols-3">
{clients.map((client) => (
<Card key={client.id}>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Settings className="h-6 w-6" />
<AppWindow className="h-5 w-5" />
{client.name}
</CardTitle>
<CardDescription>
{client.description || "暂无描述"}
</CardDescription>
</CardHeader>
<CardContent>
<Link href="/dashboard/settings">
<Button className="w-full" variant="outline">
<div className="space-y-4">
<div>
<div className="flex items-center gap-2">
<Users className="h-4 w-4 text-muted-foreground" />
<p className="text-sm text-muted-foreground"></p>
</div>
<p className="mt-1 text-2xl font-bold">
{client.stats.total}
</p>
<div className="mt-1 flex items-center gap-4 text-sm text-muted-foreground">
<span>30: {client.stats.activeLastMonth}</span>
<span>30: {client.stats.newLastMonth}</span>
</div>
</div>
<Link href={`/dashboard/clients/${client.id}`}>
<Button variant="outline" className="w-full">
</Button>
</Link>
</div>
</CardContent>
</Card>
))}
{clients.length === 0 && (
<Card className="sm:col-span-2 lg:col-span-3">
<CardHeader>
<CardTitle></CardTitle>
<CardDescription>
使 OAuth
</CardDescription>
</CardHeader>
<CardContent>
<Link href="/dashboard/clients/new">
<Button></Button>
</Link>
</CardContent>
</Card>
)}
</div>
</div>
);

View File

@ -1,68 +1,61 @@
import { NextRequest } from "next/server";
import { NextRequest, NextResponse } from "next/server";
import type { Client, Prisma } from "@prisma/client";
import { prisma } from "@/lib/prisma";
import { getCurrentUser } from "@/lib/session";
export async function PUT(
request: NextRequest,
export async function PATCH(
req: Request,
{ params }: { params: { id: string } },
) {
try {
const user = await getCurrentUser();
if (!user) {
return new Response("Unauthorized", { status: 401 });
return new NextResponse("Unauthorized", { status: 401 });
}
const data = await req.json();
const client = await prisma.client.findUnique({
where: { id: params.id },
});
if (!client) {
return new Response("Not Found", { status: 404 });
if (!client || client.userId !== user.id) {
return new NextResponse("Forbidden", { status: 403 });
}
if (client.userId !== user.id) {
return new Response("Forbidden", { status: 403 });
}
const updateData = {
name: data.name,
description: data.description,
home: data.home,
logo: data.logo,
redirectUri: data.redirectUri,
} satisfies Partial<Prisma.ClientUpdateInput>;
const formData = await request.formData();
const name = formData.get("name") as string;
const home = formData.get("home") as string;
const logo = formData.get("logo") as string;
const redirectUri = formData.get("redirectUri") as string;
const description = formData.get("description") as string;
// 验证必填字段
if (!name || !home || !logo || !redirectUri) {
return new Response("Missing required fields", { status: 400 });
// 单独处理 allowedUsers 字段
if (Array.isArray(data.allowedUsers)) {
await prisma.$executeRaw`UPDATE clients SET "allowedUsers" = ${data.allowedUsers}::text[] WHERE id = ${params.id}`;
}
const updatedClient = await prisma.client.update({
where: { id: params.id },
data: {
name,
home,
logo,
redirectUri,
description,
},
data: updateData,
});
return Response.json(updatedClient);
return NextResponse.json(updatedClient);
} catch (error) {
console.error("Error updating client:", error);
return new Response("Internal Server Error", { status: 500 });
console.error("[CLIENT_UPDATE]", error);
return new NextResponse("Internal Error", { status: 500 });
}
}
export async function DELETE(
_request: NextRequest,
_request: Request,
{ params }: { params: { id: string } },
) {
try {
const user = await getCurrentUser();
if (!user) {
return new Response("Unauthorized", { status: 401 });
return new NextResponse("Unauthorized", { status: 401 });
}
const client = await prisma.client.findUnique({
@ -70,11 +63,11 @@ export async function DELETE(
});
if (!client) {
return new Response("Not Found", { status: 404 });
return new NextResponse("Not Found", { status: 404 });
}
if (client.userId !== user.id) {
return new Response("Forbidden", { status: 403 });
return new NextResponse("Forbidden", { status: 403 });
}
// 删除相关的授权记录
@ -97,9 +90,9 @@ export async function DELETE(
where: { id: params.id },
});
return new Response(null, { status: 204 });
return new NextResponse(null, { status: 204 });
} catch (error) {
console.error("Error deleting client:", error);
return new Response("Internal Server Error", { status: 500 });
console.error("[CLIENT_DELETE]", error);
return new NextResponse("Internal Error", { status: 500 });
}
}

View File

@ -2,38 +2,70 @@
import { useState } from "react";
import { useRouter } from "next/navigation";
import { zodResolver } from "@hookform/resolvers/zod";
import type { Client } from "@prisma/client";
import { useForm } from "react-hook-form";
import * as z from "zod";
import { useToast } from "@/hooks/use-toast";
import { Button } from "@/components/ui/button";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
Form,
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
const formSchema = z.object({
name: z.string().min(1, "应用名称不能为空"),
description: z.string().optional(),
home: z.string().url("请输入有效的URL"),
logo: z.string().url("请输入有效的URL"),
redirectUri: z.string().url("请输入有效的URL"),
allowedUsers: z.string().optional(),
});
interface EditClientFormProps {
client: Client;
}
export function EditClientForm({ client }: EditClientFormProps) {
const [isLoading, setIsLoading] = useState(false);
const { toast } = useToast();
const router = useRouter();
const [isLoading, setIsLoading] = useState(false);
async function onSubmit(event: React.FormEvent<HTMLFormElement>) {
event.preventDefault();
setIsLoading(true);
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
name: client.name,
description: client.description || "",
home: client.home,
logo: client.logo,
redirectUri: client.redirectUri,
allowedUsers: client.allowedUsers?.join(", ") || "",
},
});
async function onSubmit(values: z.infer<typeof formSchema>) {
try {
const formData = new FormData(event.currentTarget);
setIsLoading(true);
const response = await fetch(`/api/clients/${client.id}`, {
method: "PUT",
body: formData,
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
...values,
allowedUsers: values.allowedUsers
? values.allowedUsers
.split(",")
.map((u) => u.trim())
.filter(Boolean)
: [],
}),
});
if (!response.ok) {
@ -41,90 +73,108 @@ export function EditClientForm({ client }: EditClientFormProps) {
}
router.refresh();
toast({
title: "更新成功",
description: "应用信息已更新",
});
router.push("/dashboard/clients");
} catch (error) {
toast({
variant: "destructive",
title: "更新失败",
description: error instanceof Error ? error.message : "未知错误",
});
console.error("Error updating client:", error);
} finally {
setIsLoading(false);
}
}
return (
<Card>
<CardHeader>
<CardTitle></CardTitle>
<CardDescription></CardDescription>
</CardHeader>
<CardContent>
<form onSubmit={onSubmit} className="space-y-4">
<div className="grid gap-2">
<Label htmlFor="name"></Label>
<Input
id="name"
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-8">
<FormField
control={form.control}
name="name"
defaultValue={client.name}
disabled={isLoading}
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="home"></Label>
<Input
id="home"
name="home"
defaultValue={client.home}
disabled={isLoading}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="logo"></Label>
<Input
id="logo"
name="logo"
defaultValue={client.logo}
disabled={isLoading}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="redirectUri"></Label>
<Input
id="redirectUri"
name="redirectUri"
defaultValue={client.redirectUri}
disabled={isLoading}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description"></Label>
<Input
id="description"
<FormField
control={form.control}
name="description"
defaultValue={client.description || ""}
disabled={isLoading}
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Textarea {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</div>
<div className="flex justify-end space-x-4">
<Button
type="button"
variant="outline"
onClick={() => router.push("/dashboard/clients")}
disabled={isLoading}
>
</Button>
<FormField
control={form.control}
name="home"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="logo"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="redirectUri"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="allowedUsers"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} placeholder="用户名列表,用逗号分隔" />
</FormControl>
<FormDescription>
Q58
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
<Button type="submit" disabled={isLoading}>
{isLoading ? "保存中..." : "保存"}
</Button>
</div>
</form>
</CardContent>
</Card>
</Form>
);
}

View File

@ -12,7 +12,7 @@ export function DashboardHeader() {
if (pathname === "/dashboard/clients") return "应用管理";
if (pathname.includes("/dashboard/clients/")) return "应用详情";
if (pathname === "/dashboard/settings") return "账号设置";
if (pathname === "/admin/users") return "用户管理";
if (pathname === "/admin/users") return "用户列表";
if (pathname === "/admin/logs") return "系统日志";
return "";
};

View File

@ -118,7 +118,7 @@ export function NavBar() {
pathname === "/admin/users" && "bg-accent",
)}
>
</Link>
</DropdownMenuItem>
<DropdownMenuItem asChild>

View File

@ -5,7 +5,7 @@ import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";
const buttonVariants = cva(
"inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0",
{
variants: {
variant: {

179
src/components/ui/form.tsx Normal file
View File

@ -0,0 +1,179 @@
"use client";
import * as React from "react";
import * as LabelPrimitive from "@radix-ui/react-label";
import { Slot } from "@radix-ui/react-slot";
import {
Controller,
ControllerProps,
FieldPath,
FieldValues,
FormProvider,
useFormContext,
} from "react-hook-form";
import { cn } from "@/lib/utils";
import { Label } from "@/components/ui/label";
const Form = FormProvider;
type FormFieldContextValue<
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>,
> = {
name: TName;
};
const FormFieldContext = React.createContext<FormFieldContextValue>(
{} as FormFieldContextValue,
);
const FormField = <
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>,
>({
...props
}: ControllerProps<TFieldValues, TName>) => {
return (
<FormFieldContext.Provider value={{ name: props.name }}>
<Controller {...props} />
</FormFieldContext.Provider>
);
};
const useFormField = () => {
const fieldContext = React.useContext(FormFieldContext);
const itemContext = React.useContext(FormItemContext);
const { getFieldState, formState } = useFormContext();
const fieldState = getFieldState(fieldContext.name, formState);
if (!fieldContext) {
throw new Error("useFormField should be used within <FormField>");
}
const { id } = itemContext;
return {
id,
name: fieldContext.name,
formItemId: `${id}-form-item`,
formDescriptionId: `${id}-form-item-description`,
formMessageId: `${id}-form-item-message`,
...fieldState,
};
};
type FormItemContextValue = {
id: string;
};
const FormItemContext = React.createContext<FormItemContextValue>(
{} as FormItemContextValue,
);
const FormItem = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => {
const id = React.useId();
return (
<FormItemContext.Provider value={{ id }}>
<div ref={ref} className={cn("space-y-2", className)} {...props} />
</FormItemContext.Provider>
);
});
FormItem.displayName = "FormItem";
const FormLabel = React.forwardRef<
React.ElementRef<typeof LabelPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root>
>(({ className, ...props }, ref) => {
const { error, formItemId } = useFormField();
return (
<Label
ref={ref}
className={cn(error && "text-destructive", className)}
htmlFor={formItemId}
{...props}
/>
);
});
FormLabel.displayName = "FormLabel";
const FormControl = React.forwardRef<
React.ElementRef<typeof Slot>,
React.ComponentPropsWithoutRef<typeof Slot>
>(({ ...props }, ref) => {
const { error, formItemId, formDescriptionId, formMessageId } =
useFormField();
return (
<Slot
ref={ref}
id={formItemId}
aria-describedby={
!error
? `${formDescriptionId}`
: `${formDescriptionId} ${formMessageId}`
}
aria-invalid={!!error}
{...props}
/>
);
});
FormControl.displayName = "FormControl";
const FormDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => {
const { formDescriptionId } = useFormField();
return (
<p
ref={ref}
id={formDescriptionId}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
);
});
FormDescription.displayName = "FormDescription";
const FormMessage = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, children, ...props }, ref) => {
const { error, formMessageId } = useFormField();
const body = error ? String(error?.message) : children;
if (!body) {
return null;
}
return (
<p
ref={ref}
id={formMessageId}
className={cn("text-sm font-medium text-destructive", className)}
{...props}
>
{body}
</p>
);
});
FormMessage.displayName = "FormMessage";
export {
useFormField,
Form,
FormItem,
FormLabel,
FormControl,
FormDescription,
FormMessage,
FormField,
};

View File

@ -0,0 +1,22 @@
import * as React from "react";
import { cn } from "@/lib/utils";
const Textarea = React.forwardRef<
HTMLTextAreaElement,
React.ComponentProps<"textarea">
>(({ className, ...props }, ref) => {
return (
<textarea
className={cn(
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-base ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
className,
)}
ref={ref}
{...props}
/>
);
});
Textarea.displayName = "Textarea";
export { Textarea };