1 change: 1 addition & 0 deletions tools/server/tests/unit/test_basic.py
@@ -65,6 +65,7 @@ def test_server_slots():
 
 def test_load_split_model():
     global server
+    server.offline = False
     server.model_hf_repo = "ggml-org/models"
     server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
     server.model_alias = "tinyllama-split"
146 changes: 145 additions & 1 deletion tools/server/tests/unit/test_router.py
@@ -17,7 +17,6 @@ def create_server():
     ]
 )
 def test_router_chat_completion_stream(model: str, success: bool):
-    # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
     global server
     server.start()
     content = ""
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
     else:
         assert ex is not None
         assert content == ""
+
+
+def _get_model_status(model_id: str) -> str:
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    for item in res.body.get("data", []):
+        if item.get("id") == model_id or item.get("model") == model_id:
+            return item["status"]["value"]
+    raise AssertionError(f"Model {model_id} not found in /models response")
+
+
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
+    deadline = time.time() + timeout
+    last_status = None
+    while time.time() < deadline:
+        last_status = _get_model_status(model_id)
+        if last_status in desired:
+            return last_status
+        time.sleep(1)
+    raise AssertionError(
+        f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+    )
+
+
+def _load_model_and_wait(
+    model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
+    load_res = server.make_request(
+        "POST", "/models/load", data={"model": model_id}, headers=headers
+    )
+    assert load_res.status_code == 200
+    assert isinstance(load_res.body, dict)
+    assert load_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
+    global server
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    _load_model_and_wait(model_id)
+
+    unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+    assert unload_res.status_code == 200
+    assert unload_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
+    global server
+    server.models_max = 2
+    server.start()
+
+    candidate_models = [
+        "ggml-org/tinygemma3-GGUF:Q8_0",
+        "ggml-org/test-model-stories260K",
+        "ggml-org/test-model-stories260K-infill",
+    ]
+
+    first, second, third = candidate_models
+
+    # Load only the first two models to fill the cache (models_max = 2)
+    _load_model_and_wait(first, timeout=120)
+    _load_model_and_wait(second, timeout=120)
+
+    # Verify both models are loaded
+    assert _get_model_status(first) == "loaded"
+    assert _get_model_status(second) == "loaded"
+
+    # Load the third model - this should trigger LRU eviction of the first model
+    _load_model_and_wait(third, timeout=120)
+
+    # Verify eviction: third is loaded, first was evicted
+    assert _get_model_status(third) == "loaded"
+    assert _get_model_status(first) == "unloaded"
+
+
+def test_router_no_models_autoload():
+    global server
+    server.no_models_autoload = True
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    _load_model_and_wait(model_id)
+
+    success_res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert success_res.status_code == 200
+    assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+    global server
+    server.api_key = "sk-router-secret"
+    server.start()
+
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+    auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 401
+    assert res.body.get("error", {}).get("type") == "authentication_error"
+
+    _load_model_and_wait(model_id, headers=auth_headers)
+
+    authed = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        headers=auth_headers,
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert authed.status_code == 200
+    assert "error" not in authed.body
28 changes: 23 additions & 5 deletions tools/server/tests/utils.py
@@ -7,6 +7,7 @@
 import os
 import re
 import json
+from json import JSONDecodeError
 import sys
 import requests
 import time
@@ -83,6 +84,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
+    models_dir: str | None = None
+    models_max: int | None = None
+    no_models_autoload: bool | None = None
     lora_files: List[str] | None = None
     enable_ctx_shift: int | None = False
     draft_min: int | None = None
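
These three new fields mirror the router's CLI flags and are forwarded by start() in the hunks below. A minimal sketch of opting in from a test (assuming the ServerPreset.router() preset shown later in this diff):

server = ServerPreset.router()
server.models_max = 2             # forwarded as --models-max
server.no_models_autoload = True  # forwarded as --no-models-autoload
server.start()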
@@ -143,6 +147,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
             server_args.extend(["--hf-file", self.model_hf_file])
+        if self.models_dir:
+            server_args.extend(["--models-dir", self.models_dir])
+        if self.models_max is not None:
+            server_args.extend(["--models-max", self.models_max])
         if self.n_batch:
             server_args.extend(["--batch-size", self.n_batch])
         if self.n_ubatch:
@@ -204,6 +212,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.no_models_autoload:
+            server_args.append("--no-models-autoload")
         if self.jinja:
             server_args.append("--jinja")
         else:
@@ -295,7 +305,13 @@ def make_request(
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json() if parse_body else None
+        if parse_body:
+            try:
+                result.body = response.json()
+            except JSONDecodeError:
+                result.body = response.text
+        else:
+            result.body = None
         print("Response from server", json.dumps(result.body, indent=2))
         return result
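
With this fallback, result.body may now be a dict, a plain string, or None rather than always parsed JSON, so callers that index into it should guard on its type. An illustrative check (not part of this PR):

res = server.make_request("GET", "/models")
if isinstance(res.body, dict):  # response parsed as JSON
    data = res.body.get("data", [])
else:
    raise AssertionError(f"expected a JSON body, got: {res.body!r}")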

@@ -434,8 +450,9 @@ def load_all() -> None:
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.offline = True # will be downloaded by load_all()
+        server.model_hf_repo = "ggml-org/test-model-stories260K"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-2"
         server.n_ctx = 512
         server.n_batch = 32
@@ -479,8 +496,8 @@ def bert_bge_small_with_fa() -> ServerProcess:
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
         server.offline = True # will be downloaded by load_all()
-        server.model_hf_repo = "ggml-org/models"
-        server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
+        server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
+        server.model_hf_file = None
         server.model_alias = "tinyllama-infill"
         server.n_ctx = 2048
         server.n_batch = 1024
@@ -537,6 +554,7 @@ def tinygemma3() -> ServerProcess:
     @staticmethod
     def router() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # router server has no models
         server.model_file = None
         server.model_alias = None