Ylem 已修改 . 還原成這個修訂版本
1 file changed, 87 insertions
mock-openai.py(檔案已創建)
@@ -0,0 +1,87 @@ | |||
1 | + | import time | |
2 | + | from typing import List, Optional | |
3 | + | ||
4 | + | from pydantic import BaseModel | |
5 | + | ||
6 | + | ||
7 | + | class ChatMessage(BaseModel): | |
8 | + | role: str | |
9 | + | content: str | |
10 | + | ||
11 | + | ||
12 | + | class ChatCompletionRequest(BaseModel): | |
13 | + | model: str = "mock-gpt-model" | |
14 | + | messages: List[ChatMessage] | |
15 | + | max_tokens: Optional[int] = 512 | |
16 | + | temperature: Optional[float] = 0.5 | |
17 | + | stream: Optional[bool] = False | |
18 | + | ||
19 | + | ||
20 | + | from fastapi import FastAPI, Depends, HTTPException, status | |
21 | + | from fastapi.security import HTTPBearer | |
22 | + | ||
23 | + | app = FastAPI(title="OpenAI-compatible API") | |
24 | + | ||
25 | + | bearer_scheme = HTTPBearer(auto_error=False) | |
26 | + | ||
27 | + | ||
28 | + | async def credentials(authorization=Depends(bearer_scheme)): | |
29 | + | if authorization and authorization.credentials == 'sk-1234': | |
30 | + | # api key is valid | |
31 | + | return authorization.credentials | |
32 | + | ||
33 | + | # raise http error 401 | |
34 | + | raise HTTPException( | |
35 | + | status_code=status.HTTP_401_UNAUTHORIZED, | |
36 | + | detail="Invalid API key", | |
37 | + | ) | |
38 | + | ||
39 | + | ||
40 | + | import asyncio | |
41 | + | import json | |
42 | + | ||
43 | + | ||
44 | + | async def _resp_async_generator(text_resp: str): | |
45 | + | # let's pretend every word is a token and return it over time | |
46 | + | tokens = text_resp.split(" ") | |
47 | + | ||
48 | + | for i, token in enumerate(tokens): | |
49 | + | chunk = { | |
50 | + | "id": i, | |
51 | + | "object": "chat.completion.chunk", | |
52 | + | "created": time.time(), | |
53 | + | "model": "mock-gpt-model", | |
54 | + | "choices": [{"delta": {"content": token + " "}}], | |
55 | + | } | |
56 | + | yield f"data: {json.dumps(chunk)}\n\n" | |
57 | + | await asyncio.sleep(0.1) | |
58 | + | yield "data: [DONE]\n\n" | |
59 | + | ||
60 | + | ||
61 | + | from starlette.responses import StreamingResponse | |
62 | + | ||
63 | + | ||
64 | + | @app.post("/v1/chat/completions", dependencies=[Depends(credentials)]) | |
65 | + | async def chat_completions(request: ChatCompletionRequest): | |
66 | + | if request.messages and request.messages[0].role == 'user': | |
67 | + | resp_content = "As a mock AI Assistant, I only can echo your last message:" + request.messages[-1].content | |
68 | + | else: | |
69 | + | resp_content = "As a mock AI Assistant, I only can echo your last message, but there were no messages!" | |
70 | + | ||
71 | + | if request.stream: | |
72 | + | return StreamingResponse(_resp_async_generator(resp_content), media_type="text/event-stream") | |
73 | + | ||
74 | + | return { | |
75 | + | "id": time.gmtime().tm_year, | |
76 | + | "object": "chat.completion", | |
77 | + | "created": time.time(), | |
78 | + | "model": request.model, | |
79 | + | "choices": [{ | |
80 | + | "message": ChatMessage(role="assistant", content=resp_content) | |
81 | + | }] | |
82 | + | } | |
83 | + | ||
84 | + | if __name__ == '__main__': | |
85 | + | import uvicorn | |
86 | + | ||
87 | + | uvicorn.run(app, host="0.0.0.0", port=8000) |
上一頁
下一頁