4 changed files with 302 additions and 1 deletions
@ -0,0 +1,142 @@ |
|||||
|
import type { RequestClient } from '../request-client'; |
||||
|
|
||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'; |
||||
|
|
||||
|
import { SSE } from './sse'; |
||||
|
|
||||
|
// 模拟 TextDecoder
|
||||
|
const OriginalTextDecoder = globalThis.TextDecoder; |
||||
|
|
||||
|
beforeEach(() => { |
||||
|
vi.stubGlobal( |
||||
|
'TextDecoder', |
||||
|
class { |
||||
|
private decoder = new OriginalTextDecoder(); |
||||
|
decode(value: Uint8Array, opts?: any) { |
||||
|
return this.decoder.decode(value, opts); |
||||
|
} |
||||
|
}, |
||||
|
); |
||||
|
}); |
||||
|
|
||||
|
// 创建 fetch mock
|
||||
|
const createFetchMock = (chunks: string[], ok = true) => { |
||||
|
const encoder = new TextEncoder(); |
||||
|
let index = 0; |
||||
|
return vi.fn().mockResolvedValue({ |
||||
|
ok, |
||||
|
status: ok ? 200 : 500, |
||||
|
body: { |
||||
|
getReader: () => ({ |
||||
|
read: async () => { |
||||
|
if (index < chunks.length) { |
||||
|
return { done: false, value: encoder.encode(chunks[index++]) }; |
||||
|
} |
||||
|
return { done: true, value: undefined }; |
||||
|
}, |
||||
|
}), |
||||
|
}, |
||||
|
}); |
||||
|
}; |
||||
|
|
||||
|
describe('sSE', () => { |
||||
|
let client: RequestClient; |
||||
|
let sse: SSE; |
||||
|
|
||||
|
beforeEach(() => { |
||||
|
vi.restoreAllMocks(); |
||||
|
client = { |
||||
|
getBaseUrl: () => 'http://localhost', |
||||
|
instance: { |
||||
|
interceptors: { |
||||
|
request: { |
||||
|
handlers: [], |
||||
|
}, |
||||
|
}, |
||||
|
}, |
||||
|
} as unknown as RequestClient; |
||||
|
sse = new SSE(client); |
||||
|
}); |
||||
|
|
||||
|
it('should call requestSSE when postSSE is used', async () => { |
||||
|
const spy = vi.spyOn(sse, 'requestSSE').mockResolvedValue(undefined); |
||||
|
await sse.postSSE('/test', { foo: 'bar' }, { headers: { a: '1' } }); |
||||
|
expect(spy).toHaveBeenCalledWith( |
||||
|
'/test', |
||||
|
{ foo: 'bar' }, |
||||
|
{ |
||||
|
headers: { a: '1' }, |
||||
|
method: 'POST', |
||||
|
}, |
||||
|
); |
||||
|
}); |
||||
|
|
||||
|
it('should throw error if fetch response not ok', async () => { |
||||
|
vi.stubGlobal('fetch', createFetchMock([], false)); |
||||
|
await expect(sse.requestSSE('/bad')).rejects.toThrow( |
||||
|
'HTTP error! status: 500', |
||||
|
); |
||||
|
}); |
||||
|
|
||||
|
it('should trigger onMessage and onEnd callbacks', async () => { |
||||
|
const messages: string[] = []; |
||||
|
const onMessage = vi.fn((msg: string) => messages.push(msg)); |
||||
|
const onEnd = vi.fn(); |
||||
|
|
||||
|
vi.stubGlobal('fetch', createFetchMock(['hello', ' world'])); |
||||
|
|
||||
|
await sse.requestSSE('/sse', undefined, { onMessage, onEnd }); |
||||
|
|
||||
|
expect(onMessage).toHaveBeenCalledTimes(2); |
||||
|
expect(messages.join('')).toBe('hello world'); |
||||
|
// onEnd 不再带参数
|
||||
|
expect(onEnd).toHaveBeenCalled(); |
||||
|
}); |
||||
|
|
||||
|
it('should apply request interceptors', async () => { |
||||
|
const interceptor = vi.fn(async (config) => { |
||||
|
config.headers['x-test'] = 'intercepted'; |
||||
|
return config; |
||||
|
}); |
||||
|
(client.instance.interceptors.request as any).handlers.push({ |
||||
|
fulfilled: interceptor, |
||||
|
}); |
||||
|
|
||||
|
// 创建 fetch mock,并挂到全局
|
||||
|
const fetchMock = createFetchMock(['data']); |
||||
|
vi.stubGlobal('fetch', fetchMock); |
||||
|
|
||||
|
await sse.requestSSE('/sse', undefined, {}); |
||||
|
|
||||
|
expect(interceptor).toHaveBeenCalled(); |
||||
|
expect(fetchMock).toHaveBeenCalledWith( |
||||
|
'http://localhost/sse', |
||||
|
expect.objectContaining({ |
||||
|
headers: expect.any(Headers), |
||||
|
}), |
||||
|
); |
||||
|
|
||||
|
const calls = fetchMock.mock?.calls; |
||||
|
expect(calls).toBeDefined(); |
||||
|
expect(calls?.length).toBeGreaterThan(0); |
||||
|
|
||||
|
const init = calls?.[0]?.[1] as RequestInit; |
||||
|
expect(init).toBeDefined(); |
||||
|
|
||||
|
const headers = init?.headers as Headers; |
||||
|
expect(headers?.get('x-test')).toBe('intercepted'); |
||||
|
expect(headers?.get('accept')).toBe('text/event-stream'); |
||||
|
}); |
||||
|
|
||||
|
it('should throw error when no reader', async () => { |
||||
|
vi.stubGlobal( |
||||
|
'fetch', |
||||
|
vi.fn().mockResolvedValue({ |
||||
|
ok: true, |
||||
|
status: 200, |
||||
|
body: null, |
||||
|
}), |
||||
|
); |
||||
|
await expect(sse.requestSSE('/sse')).rejects.toThrow('No reader'); |
||||
|
}); |
||||
|
}); |
||||
@ -0,0 +1,136 @@ |
|||||
|
import type { AxiosRequestHeaders, InternalAxiosRequestConfig } from 'axios'; |
||||
|
|
||||
|
import type { RequestClient } from '../request-client'; |
||||
|
import type { SseRequestOptions } from '../types'; |
||||
|
|
||||
|
/** |
||||
|
* SSE模块 |
||||
|
*/ |
||||
|
class SSE { |
||||
|
private client: RequestClient; |
||||
|
|
||||
|
constructor(client: RequestClient) { |
||||
|
this.client = client; |
||||
|
} |
||||
|
|
||||
|
public async postSSE( |
||||
|
url: string, |
||||
|
data?: any, |
||||
|
requestOptions?: SseRequestOptions, |
||||
|
) { |
||||
|
return this.requestSSE(url, data, { |
||||
|
...requestOptions, |
||||
|
method: 'POST', |
||||
|
}); |
||||
|
} |
||||
|
|
||||
|
/** |
||||
|
* SSE请求方法 |
||||
|
* @param url - 请求URL |
||||
|
* @param data - 请求数据 |
||||
|
* @param requestOptions - SSE请求选项 |
||||
|
*/ |
||||
|
public async requestSSE( |
||||
|
url: string, |
||||
|
data?: any, |
||||
|
requestOptions?: SseRequestOptions, |
||||
|
) { |
||||
|
const baseUrl = this.client.getBaseUrl() || ''; |
||||
|
|
||||
|
let axiosConfig: InternalAxiosRequestConfig<any> = { |
||||
|
url, |
||||
|
method: (requestOptions?.method as any) ?? 'GET', |
||||
|
headers: {} as AxiosRequestHeaders, |
||||
|
}; |
||||
|
const requestInterceptors = this.client.instance.interceptors |
||||
|
.request as any; |
||||
|
if ( |
||||
|
requestInterceptors.handlers && |
||||
|
requestInterceptors.handlers.length > 0 |
||||
|
) { |
||||
|
for (const handler of requestInterceptors.handlers) { |
||||
|
if (typeof handler?.fulfilled === 'function') { |
||||
|
const next = await handler.fulfilled(axiosConfig as any); |
||||
|
if (next) axiosConfig = next as InternalAxiosRequestConfig<any>; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
const merged = new Headers(); |
||||
|
Object.entries( |
||||
|
(axiosConfig.headers ?? {}) as Record<string, string>, |
||||
|
).forEach(([k, v]) => merged.set(k, String(v))); |
||||
|
if (requestOptions?.headers) { |
||||
|
new Headers(requestOptions.headers).forEach((v, k) => merged.set(k, v)); |
||||
|
} |
||||
|
if (!merged.has('accept')) { |
||||
|
merged.set('accept', 'text/event-stream'); |
||||
|
} |
||||
|
|
||||
|
let bodyInit = requestOptions?.body ?? data; |
||||
|
const ct = (merged.get('content-type') || '').toLowerCase(); |
||||
|
if ( |
||||
|
bodyInit && |
||||
|
typeof bodyInit === 'object' && |
||||
|
!ArrayBuffer.isView(bodyInit as any) && |
||||
|
!(bodyInit instanceof ArrayBuffer) && |
||||
|
!(bodyInit instanceof Blob) && |
||||
|
!(bodyInit instanceof FormData) && |
||||
|
ct.includes('application/json') |
||||
|
) { |
||||
|
bodyInit = JSON.stringify(bodyInit); |
||||
|
} |
||||
|
const requestInit: RequestInit = { |
||||
|
...requestOptions, |
||||
|
method: axiosConfig.method, |
||||
|
headers: merged, |
||||
|
body: bodyInit, |
||||
|
}; |
||||
|
|
||||
|
const response = await fetch(safeJoinUrl(baseUrl, url), requestInit); |
||||
|
if (!response.ok) { |
||||
|
throw new Error(`HTTP error! status: ${response.status}`); |
||||
|
} |
||||
|
|
||||
|
const reader = response.body?.getReader(); |
||||
|
const decoder = new TextDecoder(); |
||||
|
|
||||
|
if (!reader) { |
||||
|
throw new Error('No reader'); |
||||
|
} |
||||
|
let isEnd = false; |
||||
|
while (!isEnd) { |
||||
|
const { done, value } = await reader.read(); |
||||
|
if (done) { |
||||
|
isEnd = true; |
||||
|
decoder.decode(new Uint8Array(0), { stream: false }); |
||||
|
requestOptions?.onEnd?.(); |
||||
|
reader.releaseLock?.(); |
||||
|
break; |
||||
|
} |
||||
|
const content = decoder.decode(value, { stream: true }); |
||||
|
requestOptions?.onMessage?.(content); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
function safeJoinUrl(baseUrl: string | undefined, url: string): string { |
||||
|
if (!baseUrl) { |
||||
|
return url; // 没有 baseUrl,直接返回 url
|
||||
|
} |
||||
|
|
||||
|
// 如果 url 本身就是绝对地址,直接返回
|
||||
|
if (/^https?:\/\//i.test(url)) { |
||||
|
return url; |
||||
|
} |
||||
|
|
||||
|
// 如果 baseUrl 是完整 URL,就用 new URL
|
||||
|
if (/^https?:\/\//i.test(baseUrl)) { |
||||
|
return new URL(url, baseUrl).toString(); |
||||
|
} |
||||
|
|
||||
|
// 否则,当作路径拼接
|
||||
|
return `${baseUrl.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`; |
||||
|
} |
||||
|
|
||||
|
export { SSE }; |
||||
Loading…
Reference in new issue