229 lines
6.9 KiB
TypeScript
229 lines
6.9 KiB
TypeScript
import {
|
|
type ActiveSubscriptionContext,
|
|
type CreateGenerationRequestInput,
|
|
type CreateGenerationRequestDeps,
|
|
type GenerationRequestRecord,
|
|
type MarkGenerationSucceededDeps,
|
|
type SuccessfulGenerationRecord,
|
|
} from "@nproxy/domain";
|
|
import { Prisma, type PrismaClient } from "@prisma/client";
|
|
import { prisma as defaultPrisma } from "./prisma-client.js";
|
|
import { reconcileElapsedSubscription } from "./subscription-lifecycle.js";
|
|
|
|
/**
 * Storage contract for the generation workflow.
 *
 * Merges the `CreateGenerationRequestDeps` and `MarkGenerationSucceededDeps`
 * dependency contracts from `@nproxy/domain`, so one object (for example the
 * Prisma-backed store built by `createPrismaGenerationStore`) can be handed to
 * both use cases.
 */
export interface GenerationStore extends CreateGenerationRequestDeps, MarkGenerationSucceededDeps {}
|
|
|
|
/**
 * Creates a {@link GenerationStore} implemented on top of Prisma.
 *
 * @param database - Prisma client used for every query; defaults to the shared
 *   client exported by `./prisma-client.js`.
 * @returns An object implementing the generation persistence operations:
 *   idempotency lookup, active-subscription context, create, get, and
 *   mark-succeeded.
 */
export function createPrismaGenerationStore(
  database: PrismaClient = defaultPrisma,
): GenerationStore {
  /**
   * Loads the user's subscription with status "active", including its plan.
   * When several match, the one with the latest `currentPeriodEnd` wins, and
   * ties are broken by the newest `createdAt`.
   *
   * Used both for the initial lookup and as the `reload` callback handed to
   * `reconcileElapsedSubscription`, so both reads always use identical criteria.
   */
  const findLatestActiveSubscription = async (userId: string) =>
    database.subscription.findFirst({
      where: {
        userId,
        status: "active",
      },
      include: {
        plan: true,
      },
      orderBy: [{ currentPeriodEnd: "desc" }, { createdAt: "desc" }],
    });

  return {
    /**
     * Returns the request this user previously created with `idempotencyKey`,
     * or `null` if there is none.
     *
     * NOTE(review): `findFirst` has no `orderBy`. If (userId, idempotencyKey)
     * is not backed by a unique constraint, which row is returned is
     * unspecified. Confirm against the schema.
     */
    async findReusableRequest(
      userId: string,
      idempotencyKey: string,
    ): Promise<GenerationRequestRecord | null> {
      const request = await database.generationRequest.findFirst({
        where: {
          userId,
          idempotencyKey,
        },
      });

      return request ? mapGenerationRequest(request) : null;
    },

    /**
     * Resolves the user's active subscription, its plan's monthly request
     * limit, and the number of successful requests recorded in the usage
     * ledger since the current period started.
     *
     * @returns `null` when no subscription is still "active" after
     *   reconciliation.
     */
    async findActiveSubscriptionContext(
      userId: string,
    ): Promise<ActiveSubscriptionContext | null> {
      const subscription = await findLatestActiveSubscription(userId);

      // NOTE(review): presumably this transitions a subscription whose period
      // has elapsed and then re-reads it through `reload`; see
      // subscription-lifecycle.js. `reload` deliberately reuses the exact same
      // lookup as above.
      const currentSubscription = await reconcileElapsedSubscription(database, subscription, {
        reload: () => findLatestActiveSubscription(userId),
      });

      if (!currentSubscription || currentSubscription.status !== "active") {
        return null;
      }

      // Count usage from the start of the current period. Fall back to the
      // activation time, then the creation time, when earlier values are null.
      const cycleStart =
        currentSubscription.currentPeriodStart ??
        currentSubscription.activatedAt ??
        currentSubscription.createdAt;

      const usageAggregation = await database.usageLedgerEntry.aggregate({
        where: {
          userId,
          entryType: "generation_success",
          createdAt: { gte: cycleStart },
        },
        _sum: {
          deltaRequests: true,
        },
      });

      return {
        subscriptionId: currentSubscription.id,
        planId: currentSubscription.planId,
        monthlyRequestLimit: currentSubscription.plan.monthlyRequestLimit,
        // `_sum` fields are null when no ledger rows match, so treat that as zero usage.
        usedSuccessfulRequests: usageAggregation._sum.deltaRequests ?? 0,
      };
    },

    /**
     * Inserts a new generation request row. The prompt is trimmed first.
     * Optional inputs that are `undefined` are omitted from the insert
     * entirely rather than written explicitly.
     */
    async createGenerationRequest(
      input: CreateGenerationRequestInput,
    ): Promise<GenerationRequestRecord> {
      const request = await database.generationRequest.create({
        data: {
          userId: input.userId,
          mode: input.mode,
          providerModel: input.providerModel,
          prompt: input.prompt.trim(),
          resolutionPreset: input.resolutionPreset,
          batchSize: input.batchSize,
          ...(input.sourceImageKey !== undefined
            ? { sourceImageKey: input.sourceImageKey }
            : {}),
          // The column is stored as a Decimal; convert from the numeric input.
          ...(input.imageStrength !== undefined
            ? { imageStrength: new Prisma.Decimal(input.imageStrength) }
            : {}),
          ...(input.idempotencyKey !== undefined
            ? { idempotencyKey: input.idempotencyKey }
            : {}),
        },
      });

      return mapGenerationRequest(request);
    },

    /** Fetches a generation request by id, or `null` if it does not exist. */
    async getGenerationRequest(requestId: string): Promise<GenerationRequestRecord | null> {
      const request = await database.generationRequest.findUnique({
        where: {
          id: requestId,
        },
      });

      return request ? mapGenerationRequest(request) : null;
    },

    /**
     * Marks a request as succeeded and charges one unit of usage the first
     * time it succeeds. Everything runs in a single transaction.
     *
     * - Any prior status is overwritten with "succeeded". An existing
     *   `completedAt` is preserved; otherwise it is set to now.
     * - If the row is already "succeeded" with a `completedAt`, it is left
     *   untouched.
     * - A `generation_success` ledger entry is created only when the request
     *   has none yet. `quotaConsumed` reports whether this call created it.
     *
     * NOTE(review): two concurrent calls could both observe "no ledger entry".
     * Presumably a unique constraint on `generationRequestId` rejects the
     * second insert. Confirm against the schema.
     *
     * @throws Error if no request with `requestId` exists.
     */
    async markGenerationSucceeded(requestId: string): Promise<SuccessfulGenerationRecord> {
      return database.$transaction(async (transaction) => {
        const request = await transaction.generationRequest.findUnique({
          where: {
            id: requestId,
          },
          include: {
            usageLedgerEntry: true,
          },
        });

        if (!request) {
          throw new Error(`Generation request ${requestId} was not found.`);
        }

        const alreadyMarked = request.status === "succeeded" && request.completedAt !== null;

        const updatedRequest = alreadyMarked
          ? request
          : await transaction.generationRequest.update({
              where: {
                id: requestId,
              },
              data: {
                status: "succeeded",
                completedAt: request.completedAt ?? new Date(),
              },
            });

        // Quota is charged only once per request: exactly when no ledger entry exists yet.
        const quotaConsumed = !request.usageLedgerEntry;

        if (quotaConsumed) {
          await transaction.usageLedgerEntry.create({
            data: {
              userId: request.userId,
              generationRequestId: request.id,
              entryType: "generation_success",
              deltaRequests: 1,
              note: "Consumed after first successful generation result.",
            },
          });
        }

        return {
          request: mapGenerationRequest(updatedRequest),
          quotaConsumed,
        };
      });
    },
  };
}
|
|
|
|
/**
 * Converts a `generationRequest` database row into the domain's
 * `GenerationRequestRecord`.
 *
 * Nullable columns are translated to absent properties rather than `null`.
 * `imageStrength` is converted from a Prisma `Decimal` to a plain number.
 * `mode` and `status` are cast to the domain's types without runtime
 * validation.
 */
function mapGenerationRequest(
  request: {
    id: string;
    userId: string;
    mode: string;
    status: string;
    providerModel: string;
    prompt: string;
    sourceImageKey: string | null;
    resolutionPreset: string;
    batchSize: number;
    imageStrength: Prisma.Decimal | null;
    idempotencyKey: string | null;
    terminalErrorCode: string | null;
    terminalErrorText: string | null;
    requestedAt: Date;
    startedAt: Date | null;
    completedAt: Date | null;
    createdAt: Date;
    updatedAt: Date;
  },
): GenerationRequestRecord {
  const {
    id,
    userId,
    mode,
    status,
    providerModel,
    prompt,
    resolutionPreset,
    batchSize,
    requestedAt,
    createdAt,
    updatedAt,
    sourceImageKey,
    imageStrength,
    idempotencyKey,
    terminalErrorCode,
    terminalErrorText,
    startedAt,
    completedAt,
  } = request;

  return {
    id,
    userId,
    mode: mode as GenerationRequestRecord["mode"],
    status: status as GenerationRequestRecord["status"],
    providerModel,
    prompt,
    resolutionPreset,
    batchSize,
    requestedAt,
    createdAt,
    updatedAt,
    // Optional fields: include the key only when the column holds a value.
    ...(sourceImageKey === null ? {} : { sourceImageKey }),
    ...(imageStrength === null ? {} : { imageStrength: imageStrength.toNumber() }),
    ...(idempotencyKey === null ? {} : { idempotencyKey }),
    ...(terminalErrorCode === null ? {} : { terminalErrorCode }),
    ...(terminalErrorText === null ? {} : { terminalErrorText }),
    ...(startedAt === null ? {} : { startedAt }),
    ...(completedAt === null ? {} : { completedAt }),
  };
}
|