import kserve
import os
from kserve.vllm.vllm_model import VLLMModel
from kserve.vllm.utils import (
    build_vllm_engine_args,
    infer_vllm_supported_from_model_architecture,
    maybe_add_vllm_cli_parser,
)

import argparse

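# Build the command line on top of KServe's model server parser so its standard
# serving flags (e.g. --http_port) remain available alongside the options added below.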
parser = argparse.ArgumentParser(parents=[kserve.model_server.parser])
parser.add_argument(
    '--model_dir', 
    type=str,
    default="/mnt/models",
    help='A URI pointer to the model directory')
parser.add_argument(
    '--lora_dir',
    type=str,
    default=None,
    help='A URI pointer to the lora adapter model directory')
parser.add_argument(
    "--trust_remote_code",
    action="store_true",
    default=True,
    help="allow loading of models and tokenizers with custom code",
)
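# If vLLM is installed, also expose its engine CLI options on this parser so
# engine settings can be tuned from the command line.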
parser = maybe_add_vllm_cli_parser(parser)
args, _ = parser.parse_known_args()
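# Override a few engine settings: serve the weights in float16, load the model
# from the model directory, and disable vLLM's periodic stats logging.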
args.dtype = "float16"
args.model = args.model_dir
args.disable_log_stats = True
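# Translate the combined CLI arguments into vLLM engine arguments.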
engine_args = build_vllm_engine_args(args)

if __name__ == "__main__":
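    # Fail fast if vLLM is unavailable or the model architecture is not supported.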
    if not infer_vllm_supported_from_model_architecture(
            args.model,
            trust_remote_code=args.trust_remote_code,
        ):
        raise RuntimeError("vLLM is not available, or the model architecture is not supported by vLLM")

    if args.lora_dir:
        engine_args.enable_lora = True
        # The LoRA adapter may sit in a subdirectory; walk the tree and point
        # lora_dir at the directory that actually contains adapter_model.*
        for root, dirs, files in os.walk(args.lora_dir):
            for file in files:
                if file.startswith("adapter_model."):
                    args.lora_dir = root
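    # Wrap the vLLM engine behind KServe's model interface. args.model_name comes
    # from the base KServe parser (it is not added above); the LoRA adapter
    # directory, if any, is passed through as lora_model.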
    model = VLLMModel(args.model_name, engine_args, lora_model=args.lora_dir)
    model.load()
    
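    # Start the KServe model server and begin serving the loaded model.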
    kserve.ModelServer().start([model])
