Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/code/agent/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
VENV_DIR = os.getenv('VENV_DIR', WORK_DIR + '/venv')
MODEL_DIR = os.getenv('MODEL_DIR', MNT_DIR + '/models')
SKIP_SNAPSHOT_LOADING = os.getenv('SKIP_SNAPSHOT_LOADING')
SKIP_SNAPSHOT_DOWNLOADING = os.getenv('SKIP_SNAPSHOT_DOWNLOADING').lower() == 'true'
# API函数启动时是否跳过加载NAS中的custom_nodes.zip到实例磁盘,若跳过则可能遇到部分插件在多个实例并发读写NAS中插件目录时的冲突情况
SKIP_NODES_LOADING = os.getenv('SKIP_NODES_LOADING', '').lower() == 'true'
SNAPSHOT_DIR = MNT_DIR + '/snapshots'
Expand All @@ -36,6 +37,7 @@
f"{MNT_DIR}/output",
"--disable-metadata"
]
COMFY_USE_CPU = os.getenv('COMFY_USE_CPU').lower() == 'true'
SD_DIR = os.getenv('SD_DIR', WORK_DIR + '/stable-diffusion-webui')
SD_PROCESS_PORT = 7860
SD_BOOT_CMD = [
Expand All @@ -53,6 +55,9 @@
SD_PROCESS_PORT = 7861

if BACKEND_TYPE == TYPE_COMFYUI:
if COMFY_USE_CPU:
COMFYUI_BOOT_CMD.append("--cpu")

BACKEND_PROCESS_PORT = COMFYUI_PROCESS_PORT
BOOT_CMD = COMFYUI_BOOT_CMD
else:
Expand Down
12 changes: 6 additions & 6 deletions src/code/agent/routes/serverless_api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run_http():
返回值:
输出的图片数组
"""
body = request.get_json()
payload = request.get_json()
stream = is_true(request.args.get("stream"))
output_base64 = is_true(request.args.get("output_base64"))
output_oss = is_true(request.args.get("output_oss"))
Expand All @@ -86,7 +86,7 @@ def run_http():
if not stream:
try:
return self.service.run(
body,
payload,
output_base64=output_base64,
output_oss=output_oss,
task_id=task_id,
Expand Down Expand Up @@ -129,7 +129,7 @@ def run_prompt_task():
"""
try:
result = self.service.run(
body,
payload,
output_base64=output_base64,
output_oss=output_oss,
callback=do_streaming,
Expand Down Expand Up @@ -188,15 +188,15 @@ def run_ws(ws: Server):
),
)

# 获取第一个 message 作为输入的 prompt
# 获取第一个 message 作为输入有效载荷 payload
data = ws.receive()
prompt = json.loads(data)
payload = json.loads(data)

def callback(msg):
ws.send(msg)

results = self.service.run(
prompt,
payload,
output_base64=output_base64,
output_oss=output_oss,
callback=callback,
Expand Down
10 changes: 5 additions & 5 deletions src/code/agent/services/serverlessapi/serverless_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def get_oss_store(self):
constants.OSS_EXPIRES_IN_SECOND,
)

def api_prompt(self, client_id: str, prompt: Any):
def api_prompt(self, client_id: str, payload: Any):
"""
出图
"""
req = {"client_id": client_id, "prompt": prompt}
req = {"client_id": client_id, **payload}
res = requests.post(
os.path.join(self.endpoint, "prompt"),
json=req,
Expand Down Expand Up @@ -325,7 +325,7 @@ def get_status_from_store(self, task_id: str):

def run(
self,
prompt: map,
payload: map,
output_base64=False,
output_oss=False,
callback=None,
Expand All @@ -338,7 +338,7 @@ def run(
try:

# 解析请求中是否存在 base64、http url 形式的图片
prompt = self.parse_prompt(prompt)
payload["prompt"] = self.parse_prompt(payload.get("prompt", {}))

client_id = ""
prompt_id = ""
Expand Down Expand Up @@ -395,7 +395,7 @@ def on_message(ws: websocket.WebSocket, message: str):
while client_id == "":
time.sleep(0.1)

prompt_result = self.api_prompt(client_id, prompt)
prompt_result = self.api_prompt(client_id, payload)
prompt_id = prompt_result.get("prompt_id", "")

# 如果 task id 未指定,则使用 prompt id
Expand Down
49 changes: 32 additions & 17 deletions src/code/agent/services/workspace/snapshot_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def load(self, snapshot_path: str) -> Dict:

service.sub_status = StartingSubStatus.EXTRACTING.value
with self.timer("Extracting snapshot") as t_extract:
self._extract()
self._extract(snapshot_path)
stage_cost["time_extract"] = round(t_extract.elapsed, 2)

self._create_symlinks(snapshot_path)
Expand All @@ -47,7 +47,7 @@ def _download(self, snapshot_path: str):
pass

@abstractmethod
def _extract(self):
def _extract(self, snapshot_path: str):
"""解压依赖"""
pass

Expand All @@ -69,7 +69,7 @@ def _download(self, snapshot_path: str):
if os.path.exists(cache_path):
file_ops.copy(cache_path, f"{constants.WORK_DIR}/.cache.zip")

def _extract(self):
def _extract(self, snapshot_path: str):
file_ops.extract(f"{constants.WORK_DIR}/venv.tar")
file_ops.remove(f"{constants.WORK_DIR}/venv.tar")
file_ops.extract(f"{constants.WORK_DIR}/comfyui.zip")
Expand Down Expand Up @@ -98,6 +98,9 @@ def _clear(self):
file_ops.remove(constants.VENV_DIR)

def _download(self, snapshot_path: str):
if constants.SKIP_SNAPSHOT_DOWNLOADING:
return

file_ops.copy(f"{snapshot_path}/venv.tar", f"{constants.WORK_DIR}/venv.tar")
file_ops.copy(f"{snapshot_path}/comfyui.zip", f"{constants.WORK_DIR}/comfyui.zip")
cache_path = f"{snapshot_path}/.cache.zip"
Expand All @@ -106,19 +109,31 @@ def _download(self, snapshot_path: str):
if not constants.SKIP_NODES_LOADING:
file_ops.copy(f"{snapshot_path}/custom_nodes.zip", f"{constants.WORK_DIR}/custom_nodes.zip")

def _extract(self):
file_ops.extract(f"{constants.WORK_DIR}/venv.tar")
file_ops.remove(f"{constants.WORK_DIR}/venv.tar")
file_ops.extract(f"{constants.WORK_DIR}/comfyui.zip")
file_ops.remove(f"{constants.WORK_DIR}/comfyui.zip")
cache_path = f"{constants.WORK_DIR}/.cache.zip"
if os.path.exists(cache_path):
file_ops.extract(cache_path)
file_ops.remove(cache_path)
if not constants.SKIP_NODES_LOADING:
file_ops.remove(f"{constants.COMFYUI_DIR}/custom_nodes") # 解压时不会强制覆盖,需手动删除解压时会产生冲突的文件
file_ops.extract(f"{constants.WORK_DIR}/custom_nodes.zip", output_dir=f"{constants.COMFYUI_DIR}/custom_nodes")
file_ops.remove(f"{constants.WORK_DIR}/custom_nodes.zip")
def _extract(self, snapshot_path: str):
if constants.SKIP_SNAPSHOT_DOWNLOADING:
file_ops.extract(f"{snapshot_path}/venv.tar", output_dir=f"{constants.WORK_DIR}")
file_ops.extract(f"{snapshot_path}/comfyui.zip", output_dir=f"{constants.WORK_DIR}")

cache_path = f"{snapshot_path}/.cache.zip"
if os.path.exists(cache_path):
file_ops.extract(cache_path, output_dir=f"{constants.WORK_DIR}/.cache")

if not constants.SKIP_NODES_LOADING:
file_ops.remove(f"{constants.COMFYUI_DIR}/custom_nodes")
file_ops.extract(f"{snapshot_path}/custom_nodes.zip", output_dir=f"{constants.COMFYUI_DIR}/custom_nodes")
else:
file_ops.extract(f"{constants.WORK_DIR}/venv.tar")
file_ops.remove(f"{constants.WORK_DIR}/venv.tar")
file_ops.extract(f"{constants.WORK_DIR}/comfyui.zip")
file_ops.remove(f"{constants.WORK_DIR}/comfyui.zip")
cache_path = f"{constants.WORK_DIR}/.cache.zip"
if os.path.exists(cache_path):
file_ops.extract(cache_path)
file_ops.remove(cache_path)
if not constants.SKIP_NODES_LOADING:
file_ops.remove(f"{constants.COMFYUI_DIR}/custom_nodes") # 解压时不会强制覆盖,需手动删除解压时会产生冲突的文件
file_ops.extract(f"{constants.WORK_DIR}/custom_nodes.zip", output_dir=f"{constants.COMFYUI_DIR}/custom_nodes")
file_ops.remove(f"{constants.WORK_DIR}/custom_nodes.zip")

def _create_symlinks(self, snapshot_path: str):
file_ops.create_symlink(
Expand Down Expand Up @@ -147,7 +162,7 @@ def _download(self, snapshot_path: str):
if os.path.exists(cache_path):
file_ops.copy(cache_path, f"{constants.WORK_DIR}/.cache.zip")

def _extract(self):
def _extract(self, snapshot_path: str):
file_ops.extract(f"{constants.WORK_DIR}/venv.tar")
file_ops.remove(f"{constants.WORK_DIR}/venv.tar")
file_ops.extract(f"{constants.WORK_DIR}/stable-diffusion-webui.zip")
Expand Down