 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated

 import llama_cpp

-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse


 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False


-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+
+llama: Optional[llama_cpp.Llama] = None

-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+

 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
@@ -117,8 +125,6 @@ def get_llama():
117125 "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
118126)
119127
120-
121-
122128class CreateCompletionRequest (BaseModel ):
123129 prompt : Union [str , List [str ]] = Field (
124130 default = "" ,
@@ -162,7 +168,7 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -204,7 +210,7 @@ class Config:
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)


-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -257,7 +263,7 @@ class Config:
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)


-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)


-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",
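Because every endpoint now hangs off the module-level `router` and borrows the model through `Depends(get_llama)`, extra routes can reuse the same lock-guarded instance. A hypothetical illustration of that pattern (this route is not part of the diff):

```python
# Hypothetical extra route showing the Depends(get_llama) pattern used by
# the endpoints above; /v1/internal/tokenize is not part of this change.
from fastapi import Depends
from pydantic import BaseModel


class TokenizeRequest(BaseModel):
    text: str


@router.post("/v1/internal/tokenize")
def tokenize(request: TokenizeRequest, llama=Depends(get_llama)):
    # get_llama yields the global Llama instance while holding llama_lock,
    # so concurrent requests are serialized around the single model.
    return {"tokens": llama.tokenize(request.text.encode("utf-8"))}
```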