Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions web/src/app/api/v1/chat/completions/__tests__/completions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
FREEBUFF_GLM_MODEL_ID,
isFreebuffDeploymentHours,
} from '@codebuff/common/constants/freebuff-models'
import { env } from '@codebuff/internal/env'
import { formatQuotaResetCountdown, postChatCompletions } from '../_post'
import {
checkFreeModeRateLimit,
Expand Down Expand Up @@ -1075,6 +1076,116 @@ describe('/api/v1/chat/completions POST endpoint', () => {
})

describe('Successful responses', () => {
const withCanopyWaveApiKey = async (testFn: () => Promise<void>) => {
const previousCanopyWaveApiKey = env.CANOPYWAVE_API_KEY
env.CANOPYWAVE_API_KEY = 'test'
try {
await testFn()
} finally {
env.CANOPYWAVE_API_KEY = previousCanopyWaveApiKey
}
}

const createCanopyWaveFallbackRequest = (stream: boolean) =>
new NextRequest('http://localhost:3000/api/v1/chat/completions', {
method: 'POST',
headers: { Authorization: 'Bearer test-api-key-123' },
body: JSON.stringify({
model: 'minimax/minimax-m2.5',
stream,
codebuff_metadata: {
run_id: 'run-123',
client_id: 'test-client-id-123',
client_request_id: 'test-client-session-id-123',
},
}),
})

const createCanopyWaveNoWorkersThenFireworksFetch = (stream: boolean) => {
const fetchedBodies: Record<string, unknown>[] = []
const fetch = mock(
async (_url: string | URL | Request, init?: RequestInit) => {
fetchedBodies.push(JSON.parse(init?.body as string))

if (fetchedBodies.length === 1) {
return Response.json(
{
error: {
message: 'No available workers',
code: 'no_available_workers',
},
},
{ status: 503 },
)
}

if (!stream) {
return Response.json({
id: 'test-id',
model: 'accounts/fireworks/models/minimax-m2p5',
choices: [{ message: { content: 'fireworks response' } }],
usage: {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
},
})
}

const encoder = new TextEncoder()
const fireworksStream = new ReadableStream({
start(controller) {
controller.enqueue(
encoder.encode(
'data: {"id":"test-id","model":"accounts/fireworks/models/minimax-m2p5","choices":[{"delta":{"content":"test"}}]}\n\n',
),
)
controller.enqueue(encoder.encode('data: [DONE]\n\n'))
controller.close()
},
})

return new Response(fireworksStream, {
status: 200,
headers: { 'Content-Type': 'text/event-stream' },
})
},
) as unknown as typeof globalThis.fetch

return { fetch, fetchedBodies }
}

const postCanopyWaveFallbackRequest = async ({
fetch,
stream,
}: {
fetch: typeof globalThis.fetch
stream: boolean
}) =>
postChatCompletions({
req: createCanopyWaveFallbackRequest(stream),
getUserInfoFromApiKey: mockGetUserInfoFromApiKey,
logger: mockLogger,
trackEvent: mockTrackEvent,
getUserUsageData: mockGetUserUsageData,
getAgentRunFromId: mockGetAgentRunFromId,
fetch,
insertMessageBigquery: mockInsertMessageBigquery,
loggerWithContext: mockLoggerWithContext,
checkSessionAdmissible: mockCheckSessionAdmissibleAllow,
})

const expectCanopyWaveThenFireworks = (
fetchedBodies: Record<string, unknown>[],
) => {
expect(fetchedBodies).toHaveLength(2)
expect(fetchedBodies[0].model).toBe('minimax/minimax-m2.5')
expect(fetchedBodies[1].model).toBe(
'accounts/fireworks/models/minimax-m2p5',
)
expect(mockLogger.warn).toHaveBeenCalled()
}

it('returns stream with correct headers', async () => {
const req = new NextRequest(
'http://localhost:3000/api/v1/chat/completions',
Expand Down Expand Up @@ -1158,6 +1269,48 @@ describe('/api/v1/chat/completions POST endpoint', () => {
},
FETCH_PATH_TEST_TIMEOUT_MS,
)

it(
'falls back to Fireworks when CanopyWave has no available workers for non-streaming requests',
async () => {
await withCanopyWaveApiKey(async () => {
const { fetch, fetchedBodies } =
createCanopyWaveNoWorkersThenFireworksFetch(false)
const response = await postCanopyWaveFallbackRequest({
fetch,
stream: false,
})

expect(response.status).toBe(200)
expectCanopyWaveThenFireworks(fetchedBodies)

const body = await response.json()
expect(body.model).toBe('minimax/minimax-m2.5')
expect(body.provider).toBe('Fireworks')
expect(body.choices[0].message.content).toBe('fireworks response')
})
},
FETCH_PATH_TEST_TIMEOUT_MS,
)

it(
'falls back to Fireworks when CanopyWave has no available workers for streaming requests',
async () => {
await withCanopyWaveApiKey(async () => {
const { fetch, fetchedBodies } =
createCanopyWaveNoWorkersThenFireworksFetch(true)
const response = await postCanopyWaveFallbackRequest({
fetch,
stream: true,
})

expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('text/event-stream')
expectCanopyWaveThenFireworks(fetchedBodies)
})
},
FETCH_PATH_TEST_TIMEOUT_MS,
)
})

describe('Subscription limit enforcement', () => {
Expand Down
94 changes: 74 additions & 20 deletions web/src/app/api/v1/chat/completions/_post.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,50 @@ export const formatQuotaResetCountdown = (
return `in ${pluralize(minutes, 'minute')}`
}

type ProviderHandlerArgs = Parameters<typeof handleCanopyWaveStream>[0]
type ProviderHandler<T> = (args: ProviderHandlerArgs) => Promise<T>

function shouldFallbackCanopyWaveToFireworks(
error: unknown,
model: string,
): error is CanopyWaveError {
if (!(error instanceof CanopyWaveError) || !isFireworksModel(model)) {
return false
}
const message = error.errorBody.error.message.toLowerCase()
return (
error.statusCode === 429 ||
error.statusCode >= 500 ||
message.includes('no available workers')
)
}

async function handleCanopyWaveWithFireworksFallback<T>(
args: ProviderHandlerArgs,
handleCanopyWave: ProviderHandler<T>,
handleFireworks: ProviderHandler<T>,
): Promise<T> {
try {
return await handleCanopyWave(args)
} catch (error) {
if (!shouldFallbackCanopyWaveToFireworks(error, args.body.model)) {
throw error
}

args.logger.warn(
{
error: getErrorObject(error),
model: args.body.model,
providerStatusCode: error.statusCode,
providerStatusText: error.statusText,
},
'CanopyWave request failed, falling back to Fireworks',
)

return handleFireworks(args)
}
}

export type CheckSessionAdmissibleFn = typeof checkSessionAdmissible

type GateRejectCode = Extract<SessionGateResult, { ok: false }>['code']
Expand Down Expand Up @@ -599,7 +643,8 @@ export async function postChatCompletions(params: {
if (bodyStream) {
// Streaming request — route to SiliconFlow/CanopyWave/Fireworks for supported models
const useSiliconFlow = false // isSiliconFlowModel(typedBody.model)
const useCanopyWave = isCanopyWaveModel(typedBody.model)
const useCanopyWave =
!!env.CANOPYWAVE_API_KEY && isCanopyWaveModel(typedBody.model)
const useFireworks = !useCanopyWave && isFireworksModel(typedBody.model)
const useOpenAIDirect =
!useCanopyWave &&
Expand All @@ -616,15 +661,19 @@ export async function postChatCompletions(params: {
insertMessageBigquery,
})
: useCanopyWave
? await handleCanopyWaveStream({
body: typedBody,
userId,
stripeCustomerId,
agentId,
fetch,
logger,
insertMessageBigquery,
})
? await handleCanopyWaveWithFireworksFallback(
{
body: typedBody,
userId,
stripeCustomerId,
agentId,
fetch,
logger,
insertMessageBigquery,
},
handleCanopyWaveStream,
handleFireworksStream,
)
: useFireworks
? await handleFireworksStream({
body: typedBody,
Expand Down Expand Up @@ -678,7 +727,8 @@ export async function postChatCompletions(params: {
// Non-streaming request — route to SiliconFlow/CanopyWave/Fireworks for supported models
const model = typedBody.model
const useSiliconFlow = false // isSiliconFlowModel(model)
const useCanopyWave = isCanopyWaveModel(model)
const useCanopyWave =
!!env.CANOPYWAVE_API_KEY && isCanopyWaveModel(model)
const useFireworks = !useCanopyWave && isFireworksModel(model)
const shouldUseOpenAIEndpoint =
!useCanopyWave && !useFireworks && isOpenAIDirectModel(model)
Expand All @@ -694,15 +744,19 @@ export async function postChatCompletions(params: {
insertMessageBigquery,
})
: useCanopyWave
? handleCanopyWaveNonStream({
body: typedBody,
userId,
stripeCustomerId,
agentId,
fetch,
logger,
insertMessageBigquery,
})
? handleCanopyWaveWithFireworksFallback(
{
body: typedBody,
userId,
stripeCustomerId,
agentId,
fetch,
logger,
insertMessageBigquery,
},
handleCanopyWaveNonStream,
handleFireworksNonStream,
)
: useFireworks
? handleFireworksNonStream({
body: typedBody,
Expand Down
Loading