fix: make invoice activation race-safe

This commit is contained in:
sirily
2026-03-10 17:51:33 +03:00
parent 336cb7f33e
commit 729f1af3c8
2 changed files with 124 additions and 15 deletions

View File

@@ -24,12 +24,12 @@ test("markInvoicePaid activates a pending invoice once and writes an admin audit
assert.equal(result.status, "paid"); assert.equal(result.status, "paid");
assert.ok(result.paidAt instanceof Date); assert.ok(result.paidAt instanceof Date);
assert.equal(database.calls.paymentInvoiceUpdate.length, 1); assert.equal(database.calls.paymentInvoiceUpdateMany.length, 1);
assert.equal(database.calls.subscriptionUpdate.length, 1); assert.equal(database.calls.subscriptionUpdate.length, 1);
assert.equal(database.calls.usageLedgerCreate.length, 1); assert.equal(database.calls.usageLedgerCreate.length, 1);
assert.equal(database.calls.adminAuditCreate.length, 1); assert.equal(database.calls.adminAuditCreate.length, 1);
const paymentUpdate = database.calls.paymentInvoiceUpdate[0] as ({ const paymentUpdate = database.calls.paymentInvoiceUpdateMany[0] as ({
data: { status: "paid"; paidAt: Date }; data: { status: "paid"; paidAt: Date };
} | undefined); } | undefined);
const auditEntry = database.calls.adminAuditCreate[0]; const auditEntry = database.calls.adminAuditCreate[0];
@@ -75,7 +75,7 @@ test("markInvoicePaid is idempotent for already paid invoices", async () => {
assert.equal(result.status, "paid"); assert.equal(result.status, "paid");
assert.equal(result.paidAt?.toISOString(), paidAt.toISOString()); assert.equal(result.paidAt?.toISOString(), paidAt.toISOString());
assert.equal(database.calls.paymentInvoiceUpdate.length, 0); assert.equal(database.calls.paymentInvoiceUpdateMany.length, 0);
assert.equal(database.calls.subscriptionUpdate.length, 0); assert.equal(database.calls.subscriptionUpdate.length, 0);
assert.equal(database.calls.usageLedgerCreate.length, 0); assert.equal(database.calls.usageLedgerCreate.length, 0);
assert.equal(database.calls.adminAuditCreate.length, 1); assert.equal(database.calls.adminAuditCreate.length, 1);
@@ -108,36 +108,100 @@ test("markInvoicePaid rejects invalid terminal invoice transitions", async () =>
error.message === 'Invoice in status "expired" cannot be marked paid.', error.message === 'Invoice in status "expired" cannot be marked paid.',
); );
assert.equal(database.calls.paymentInvoiceUpdate.length, 0); assert.equal(database.calls.paymentInvoiceUpdateMany.length, 0);
assert.equal(database.calls.subscriptionUpdate.length, 0); assert.equal(database.calls.subscriptionUpdate.length, 0);
assert.equal(database.calls.usageLedgerCreate.length, 0); assert.equal(database.calls.usageLedgerCreate.length, 0);
assert.equal(database.calls.adminAuditCreate.length, 0); assert.equal(database.calls.adminAuditCreate.length, 0);
}); });
test("markInvoicePaid treats a concurrent pending->paid race as a replay without duplicate side effects", async () => {
const paidAt = new Date("2026-03-10T12:00:00.000Z");
const database = createBillingDatabase({
invoice: createInvoiceFixture({
status: "pending",
paidAt: null,
subscription: createSubscriptionFixture(),
}),
updateManyCount: 0,
invoiceAfterFailedTransition: createInvoiceFixture({
status: "paid",
paidAt,
subscription: createSubscriptionFixture(),
}),
});
const store = createPrismaBillingStore(database.client);
const result = await store.markInvoicePaid({
invoiceId: "invoice_1",
actor: {
type: "web_admin",
ref: "admin_user_1",
},
});
assert.equal(result.status, "paid");
assert.equal(result.paidAt?.toISOString(), paidAt.toISOString());
assert.equal(database.calls.paymentInvoiceUpdateMany.length, 1);
assert.equal(database.calls.subscriptionUpdate.length, 0);
assert.equal(database.calls.usageLedgerCreate.length, 0);
assert.equal(database.calls.adminAuditCreate.length, 1);
assert.equal(database.calls.adminAuditCreate[0]?.action, "invoice_mark_paid_replayed");
assert.equal(database.calls.adminAuditCreate[0]?.metadata?.replayed, true);
});
function createBillingDatabase(input: { function createBillingDatabase(input: {
invoice: ReturnType<typeof createInvoiceFixture>; invoice: ReturnType<typeof createInvoiceFixture>;
updateManyCount?: number;
invoiceAfterFailedTransition?: ReturnType<typeof createInvoiceFixture>;
}) { }) {
const calls = { const calls = {
paymentInvoiceUpdate: [] as Array<Record<string, unknown>>, paymentInvoiceUpdateMany: [] as Array<Record<string, unknown>>,
subscriptionUpdate: [] as Array<Record<string, unknown>>, subscriptionUpdate: [] as Array<Record<string, unknown>>,
usageLedgerCreate: [] as Array<Record<string, unknown>>, usageLedgerCreate: [] as Array<Record<string, unknown>>,
adminAuditCreate: [] as Array<Record<string, any>>, adminAuditCreate: [] as Array<Record<string, any>>,
}; };
let currentInvoice = input.invoice; let currentInvoice = input.invoice;
let findUniqueCallCount = 0;
const transaction = { const transaction = {
paymentInvoice: { paymentInvoice: {
findUnique: async () => currentInvoice, findUnique: async () => {
update: async ({ data }: { data: { status: "paid"; paidAt: Date } }) => { findUniqueCallCount += 1;
calls.paymentInvoiceUpdate.push({ data });
currentInvoice = { if (
...currentInvoice, input.invoiceAfterFailedTransition &&
status: data.status, input.updateManyCount === 0 &&
paidAt: data.paidAt, findUniqueCallCount > 1
}; ) {
currentInvoice = input.invoiceAfterFailedTransition;
}
return currentInvoice; return currentInvoice;
}, },
updateMany: async ({
where,
data,
}: {
where: { id: string; status: "pending" };
data: { status: "paid"; paidAt: Date };
}) => {
calls.paymentInvoiceUpdateMany.push({ where, data });
const count =
input.updateManyCount ??
(currentInvoice.id === where.id && currentInvoice.status === where.status ? 1 : 0);
if (count > 0) {
currentInvoice = {
...currentInvoice,
status: data.status,
paidAt: data.paidAt,
};
}
return { count };
},
}, },
subscription: { subscription: {
update: async ({ data }: { data: Record<string, unknown> }) => { update: async ({ data }: { data: Record<string, unknown> }) => {

View File

@@ -170,14 +170,59 @@ export function createPrismaBillingStore(database: PrismaClient = defaultPrisma)
} }
const paidAt = invoice.paidAt ?? new Date(); const paidAt = invoice.paidAt ?? new Date();
const updatedInvoice = await transaction.paymentInvoice.update({ const transitionResult = await transaction.paymentInvoice.updateMany({
where: { id: invoice.id }, where: {
id: invoice.id,
status: "pending",
},
data: { data: {
status: "paid", status: "paid",
paidAt, paidAt,
}, },
}); });
if (transitionResult.count === 0) {
const currentInvoice = await transaction.paymentInvoice.findUnique({
where: { id: input.invoiceId },
include: {
subscription: {
include: {
plan: true,
},
},
},
});
if (!currentInvoice) {
throw new BillingError("invoice_not_found", "Invoice not found.");
}
if (currentInvoice.status === "paid") {
await writeInvoicePaidAuditLog(transaction, currentInvoice, input.actor, true);
return mapInvoice(currentInvoice);
}
throw new BillingError(
"invoice_transition_not_allowed",
`Invoice in status "${currentInvoice.status}" cannot be marked paid.`,
);
}
const updatedInvoice = await transaction.paymentInvoice.findUnique({
where: { id: invoice.id },
include: {
subscription: {
include: {
plan: true,
},
},
},
});
if (!updatedInvoice) {
throw new BillingError("invoice_not_found", "Invoice not found.");
}
if (invoice.subscription) { if (invoice.subscription) {
const periodStart = paidAt; const periodStart = paidAt;
const periodEnd = addDays(periodStart, 30); const periodEnd = addDays(periodStart, 30);