import gc
import re
import os
import ray
import ssl
import math
import time
import json
import torch
import argparse
import uvicorn

from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_model_parallel


# Handed to uvicorn.run() as `timeout_keep_alive` when this file is run as a script.
TIMEOUT_KEEP_ALIVE = 5  # seconds.
# FastAPI application object: the /benchmark route is registered on it and
# uvicorn serves it from the __main__ section at the bottom of the file.
app = FastAPI()


def convert_model_size_to_mn(value, quantization=False, bits=8):
    """Turn a parameter count into an estimated weight footprint.

    Every parameter is first costed at 2 units (presumably the bytes of a
    16-bit weight; the visible caller divides the result by 1024**3 and
    reports it as GB). When ``quantization`` is enabled, that baseline is
    divided by 4 for 4-bit weights and by 2 for 8-bit or any other ``bits``
    value, and 6% of the *unquantized* baseline is added on top as overhead.

    Args:
        value: Number of parameters.
        quantization: Whether the weights are quantized.
        bits: Quantization width; only ``4`` is treated differently from
            the default handling.

    Returns:
        The estimated size: ``value * 2`` unchanged when not quantized,
        otherwise the reduced size plus the 6% overhead.
    """
    baseline = value * 2
    if quantization:
        overhead = 0.06 * baseline
        # 4-bit quarters the baseline; 8-bit and every unrecognised width halve it.
        shrink = 4 if bits == 4 else 2
        stored = baseline / shrink
    else:
        overhead = 0
        stored = baseline
    return stored + overhead


def compute_inference_activation_memory(context_len, hidden_size, num_attention_heads):
    """Estimate activation memory for a single inference pass.

    The estimate is the sum of two terms:

    * ``context_len * hidden_size * 5 * 2`` -- grows linearly with the
      context length.
    * ``context_len ** 2 * num_attention_heads * 2`` -- grows quadratically
      with the context length, once per attention head.

    The trailing ``* 2`` in each term is presumably bytes per 16-bit value
    (the visible caller divides the result by 1024**3 and reports GB).

    Returns:
        The sum of both terms, not yet converted to GB.
    """
    per_token_term = context_len * hidden_size * 5 * 2
    per_head_term = context_len * context_len * num_attention_heads * 2
    return per_token_term + per_head_term


def compute_model_size(vocab_size, num_hidden_layers, hidden_size, intermediate_size):
    """Estimate a model's parameter count from its architecture sizes.

    Three contributions are summed:

    * ``vocab_size * hidden_size * 2`` -- two vocabulary-sized matrices
      (presumably the input embedding and the output head).
    * ``num_hidden_layers * 4 * hidden_size ** 2`` -- four square matrices
      per layer (presumably the attention projections).
    * ``num_hidden_layers * 3 * intermediate_size * hidden_size`` -- three
      hidden-by-intermediate matrices per layer (presumably a gated MLP).

    Returns:
        The total as a parameter count (not bytes).
    """
    vocab_params = vocab_size * hidden_size * 2
    square_params = num_hidden_layers * 4 * hidden_size * hidden_size
    wide_params = num_hidden_layers * 3 * intermediate_size * hidden_size
    return vocab_params + square_params + wide_params


def calc_memory_by_formula(model_path):
    """Predict memory usage (in GB) from the model's ``config.json``.

    The prediction assumes a fixed context of 4096 input plus 4096 output
    tokens and uses only the architecture sizes stored in the config file.

    Args:
        model_path: Directory that contains the model's ``config.json``.

    Returns:
        A tuple ``(model_size, inference_memory, activation_memory)``, each
        already divided by ``1024 ** 3``. The API endpoint labels
        ``inference_memory`` as KV-cache memory.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``model_path``.
        KeyError: If the config lacks ``hidden_size``, ``num_hidden_layers``,
            ``num_attention_heads``, ``vocab_size`` or ``intermediate_size``.
    """
    input_len = 4096
    output_len = 4096
    context_len = input_len + output_len
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(os.path.join(model_path, 'config.json'), 'r', encoding='utf-8') as f:
        c = json.load(f)
    # Inference (KV cache) memory
    inference_memory = (2 * context_len * 2 * 2 * c['hidden_size'] * c['num_hidden_layers']) / (1024 ** 3)
    # Activation memory
    activation_memory = compute_inference_activation_memory(context_len, c['hidden_size'], c['num_attention_heads']) / (1024 ** 3)
    # Model (weights) memory
    model_size = compute_model_size(c['vocab_size'], c['num_hidden_layers'], c['hidden_size'], c['intermediate_size'])
    quant_config = c.get('quantization_config')
    if quant_config:
        # Not every quantization scheme stores a `bits` entry: bitsandbytes
        # configs use `load_in_4bit` / `load_in_8bit` flags instead, and some
        # schemes carry no width at all. Fall back to 8 bits (the default of
        # convert_model_size_to_mn) rather than failing with a KeyError.
        bits = quant_config.get('bits')
        if bits is None:
            bits = 4 if quant_config.get('load_in_4bit') else 8
        model_size = convert_model_size_to_mn(model_size, True, bits)
    else:
        # A missing, null or empty quantization_config means unquantized weights.
        model_size = convert_model_size_to_mn(model_size)
    model_size = model_size / (1024 ** 3)
    print('Inference Memory: ', inference_memory)
    print('Model Memory: ', model_size)
    print('Activation Memory: ', activation_memory)
    return model_size, inference_memory, activation_memory



def calc_memory_by_load_model(model_path, tensor_parallel_size):
    """Load the model with vLLM and report the weight memory it actually uses.

    The engine is created with ``max_model_len`` of 8192 tokens, the memory
    usage recorded by the driver worker's model runner is read, and the
    engine is torn down again. Teardown runs whether or not loading or
    measuring succeeded, so a failed attempt does not leave the engine,
    cached CUDA blocks or Ray workers alive in this server process.

    Args:
        model_path: Model location passed to ``vllm.LLM`` as ``model``.
        tensor_parallel_size: Number of tensor-parallel workers. For values
            above 1 the per-worker figure is multiplied by this number.

    Returns:
        The measured memory divided by ``1024 ** 3``, or ``None`` if loading
        or measuring raised. Errors are printed, never propagated; an error
        during teardown does not discard a measurement already taken.
    """
    llm = None
    model_memory = None
    try:
        llm = LLM(model=model_path, max_model_len=4096*2, tensor_parallel_size=tensor_parallel_size)
        # TODO: fails with multiple GPUs:
        # 'RayWorkerWrapper' object has no attribute 'model_runner'
        if tensor_parallel_size > 1:
            model_memory = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model_memory_usage / (1024 ** 3) * tensor_parallel_size
        else:
            model_memory = llm.llm_engine.model_executor.driver_worker.model_runner.model_memory_usage / (1024 ** 3)
    except Exception as e:
        model_memory = None
        print(e)
    finally:
        released = True
        # Stage 1: dismantle the vLLM objects that own the GPU allocations.
        try:
            destroy_model_parallel()
            if llm is not None:
                del llm.llm_engine.model_executor.driver_worker
        except Exception as e:
            released = False
            print(e)
        # Drop the last local reference so the collector can reclaim the engine.
        llm = None
        # Stage 2: return memory to the system. Attempted even if stage 1
        # failed, because whatever is already unreachable can still be freed.
        try:
            gc.collect()
            torch.cuda.empty_cache()
            ray.shutdown()
        except Exception as e:
            released = False
            print(e)
        if released:
            print("Successfully delete the llm pipeline and free the GPU memory!")
    return model_memory


class Item(BaseModel):
    # Request body accepted by the POST /benchmark endpoint.

    # Model location: passed to vllm.LLM as `model`, and also used as the
    # directory from which config.json is read for the formula-based estimate.
    llm_model_path: str
    # Forwarded to vllm.LLM as `tensor_parallel_size`; defaults to 1.
    tensor_parallel_size: int = 1



@app.post("/benchmark")
def benchmark_api(item: Item):
    # Endpoint: measure the model's real memory usage by loading it, compute
    # the formula-based predictions, and return all four figures rounded UP
    # to two decimal places. The measured figure is None when the load
    # failed (or reported a falsy value).
    def ceil_2dp(gigabytes):
        return math.ceil(gigabytes*100)/100

    # The real load runs first, the closed-form estimate second.
    measured = calc_memory_by_load_model(item.llm_model_path, item.tensor_parallel_size)
    est_model, est_kv_cache, est_activation = calc_memory_by_formula(item.llm_model_path)

    report = {}
    report["Real Model Memory Usage (GB)"] = ceil_2dp(measured) if measured else None
    report["Predict Model Memory Usage (GB)"] = ceil_2dp(est_model)
    report["Predict KV Cache Memory Usage (GB)"] = ceil_2dp(est_kv_cache)
    report["Predict Activation Memory Usage (GB)"] = ceil_2dp(est_activation)
    return report


if __name__ == "__main__":
    # Script entry point: read the server options from the command line and
    # serve the FastAPI app with uvicorn.

    # server setting -- (flag, add_argument keyword options), in help order
    cli_options = (
        ("--host", dict(type=str, default=None)),
        ("--port", dict(type=int, default=8000)),
        ("--ssl-keyfile", dict(type=str, default=None)),
        ("--ssl-certfile", dict(type=str, default=None)),
        ("--ssl-ca-certs", dict(
            type=str,
            default=None,
            help="The CA certificates file",
        )),
        ("--ssl-cert-reqs", dict(
            type=int,
            default=int(ssl.CERT_NONE),
            help="Whether client certificate is required (see stdlib ssl module's)",
        )),
        ("--root-path", dict(
            type=str,
            default=None,
            help="FastAPI root_path when app is behind a path based routing proxy",
        )),
        ("--log-level", dict(type=str, default="debug")),
    )
    cli = argparse.ArgumentParser()
    for flag, flag_kwargs in cli_options:
        cli.add_argument(flag, **flag_kwargs)

    # run server
    opts = cli.parse_args()
    app.root_path = opts.root_path
    server_kwargs = {
        "host": opts.host,
        "port": opts.port,
        "log_level": opts.log_level,
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
        "ssl_keyfile": opts.ssl_keyfile,
        "ssl_certfile": opts.ssl_certfile,
        "ssl_ca_certs": opts.ssl_ca_certs,
        "ssl_cert_reqs": opts.ssl_cert_reqs,
    }
    uvicorn.run(app, **server_kwargs)
