Is 7B llama speed expected to be slow? #19

@w32zhong

Hello, thank you for open-sourcing such solid work! Feel free to add me on WeChat (hellozhongwei) for an offline chat!

I understand that, in the paper, the inference speed in Figure 2 is measured only on the gate_proj linear operation of 70B LLaMA. The speedup looks impressive, although I would have assumed that de-quantization and re-scaling in the CUDA kernel add significant overhead.

My hypothesis is that the speedup comes from single-batch inference being memory-bound, so that reading 2-bit weights instead of fp16 weights reduces memory traffic. But if that is the case, shouldn't full-model inference at batch size 1 be faster as well? I do not have enough hardware resources for the 70B model, so I tested the smaller LLaMA 7B checkpoint: ChenMnZ/Llama-2-7b-EfficientQAT-w2g64-BitBLAS. However, the 2-bit BitBLAS version reaches only around 14.5 tokens/s, while the native Hugging Face fp16 model is faster (20 tokens/s), even though the latter is split across two GPUs with model parallelism.
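To make the memory-bound argument concrete, here is a rough back-of-envelope sketch. It assumes that batch-size-1 decoding has to read every weight once per generated token, so memory bandwidth divided by weight size gives an upper bound on tokens/s. The parameter count, the roughly 360 GB/s bandwidth figure for an RTX 3060, and the 2.5 effective bits per weight for w2g64 (one 16-bit scale plus one 16-bit zero point per group of 64) are all assumptions for illustration, not measurements, and the estimate ignores the KV cache, activations, and any layers left in fp16.

```python
# Back-of-envelope bandwidth bound for batch-size-1 decoding.
# Every constant here is an assumption for illustration, not a measurement.


def weight_bytes(n_params: float, bits_per_weight: float) -> float:
    """Approximate size of the weights in bytes."""
    return n_params * bits_per_weight / 8


def max_tokens_per_second(n_params: float, bits_per_weight: float,
                          bandwidth_bytes_per_s: float) -> float:
    """Upper bound on tokens/s if each token requires reading all weights once."""
    return bandwidth_bytes_per_s / weight_bytes(n_params, bits_per_weight)


N_PARAMS = 7e9      # nominal "7B" parameter count (assumption)
BANDWIDTH = 360e9   # assumed nominal memory bandwidth of one RTX 3060, in bytes/s

# Assumed storage cost of w2g64: 2-bit weights plus a 16-bit scale and a
# 16-bit zero point shared by each group of 64 weights.
W2G64_BITS = 2 + (16 + 16) / 64

fp16_bound = max_tokens_per_second(N_PARAMS, 16, BANDWIDTH)
w2_bound = max_tokens_per_second(N_PARAMS, W2G64_BITS, BANDWIDTH)

print(f"fp16 bound:  {fp16_bound:.1f} tokens/s")
print(f"w2g64 bound: {w2_bound:.1f} tokens/s")
print(f"ratio:       {w2_bound / fp16_bound:.1f}x")
```

Under these assumptions the fp16 bound comes out at roughly 26 tokens/s and the 2-bit bound at roughly 165 tokens/s. If the two GPUs process their layers one after another, the single-GPU bandwidth is still the relevant figure. A 2-bit result far below its bound would suggest that something other than weight traffic limits the speed, but this sketch cannot say what.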

My question is whether this is expected. Since BitBLAS already applies efficient scheduling to its CUDA kernels, I would expect end-to-end inference to be faster as well, in line with what you report in Figure 2. Why is it slower?

Test devices: 2x RTX3060

Test code:

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextStreamer
from gptqmodel import GPTQModel

# ref model
ref_model_path = "NousResearch/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(ref_model_path)
model = AutoModelForCausalLM.from_pretrained(ref_model_path,
    torch_dtype=torch.float16, device_map='auto', load_in_8bit=False)
streamer = TextStreamer(tokenizer)

start = time.time()
output = model.generate(
    **tokenizer("Solar eclipse is ", return_tensors="pt").to(model.device),
    max_new_tokens=256, streamer=streamer, use_cache=True
)
end = time.time()

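# Note: output contains the prompt tokens as well, so output_len slightly
# overstates the number of newly generated tokens.
output_len = output.shape[-1]
delta_time = end - start
# Prints: total tokens, elapsed seconds, tokens per second
print(output_len, delta_time, output_len / delta_time)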

# 2-bit model in BitBLAS
model_path = "ChenMnZ/Llama-2-7b-EfficientQAT-w2g64-BitBLAS"

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = GPTQModel.from_quantized(model_path)
streamer = TextStreamer(tokenizer)

start = time.time()
output = model.generate(
    **tokenizer("Solar eclipse is ", return_tensors="pt").to(model.device),
    max_new_tokens=256, streamer=streamer, use_cache=True
)
end = time.time()

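# Note: output contains the prompt tokens as well, so output_len slightly
# overstates the number of newly generated tokens.
output_len = output.shape[-1]
delta_time = end - start
# Prints: total tokens, elapsed seconds, tokens per second
print(output_len, delta_time, output_len / delta_time)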
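The script above times a single run per model and divides the full sequence length (prompt included) by wall-clock time, so one-off warm-up costs can skew the comparison. A slightly more careful measurement could look like the sketch below. The helper `measure_decode_speed` and its parameters are hypothetical names introduced here, not part of transformers or gptqmodel; it takes any callable that performs one generation and returns the total sequence length.

```python
import time
from typing import Callable, Optional


def measure_decode_speed(generate: Callable[[], int],
                         prompt_len: int,
                         warmup: int = 1,
                         runs: int = 3,
                         sync: Optional[Callable[[], None]] = None) -> float:
    """Return newly generated tokens per second, averaged over `runs`.

    `generate` runs one generation and returns the total sequence length
    (prompt + new tokens). `sync`, if given, is called before reading the
    clock, e.g. to wait for queued GPU work to finish.
    """
    # Warm-up runs are excluded from timing (kernel compilation, caches, ...).
    for _ in range(warmup):
        generate()

    new_tokens = 0
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(runs):
        # Count only the tokens produced after the prompt.
        new_tokens += generate() - prompt_len
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return new_tokens / elapsed


# Stand-in generator so the sketch runs without a GPU or model:
# pretends to emit 20 new tokens after a 5-token prompt.
def fake_generate() -> int:
    time.sleep(0.01)
    return 5 + 20


speed = measure_decode_speed(fake_generate, prompt_len=5)
print(f"{speed:.1f} new tokens/s")
```

With the script above, `generate` could be something like `lambda: model.generate(**inputs, max_new_tokens=256).shape[-1]`, `prompt_len` would be `inputs["input_ids"].shape[-1]`, and `sync` could be `torch.cuda.synchronize`. Leaving out the `TextStreamer` during timed runs also keeps console printing out of the measurement.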
