【Arc A770】Using "Llama-3.1-Swallow-8B-Instruct-v0.3" with Gradio


Test machine

Ubuntu 24.04
Python 3.12
Intel Arc A770

Python environment

I verified operation in two environments: stock PyTorch and IPEX (Intel Extension for PyTorch). A quick device-visibility check follows the package lists below.

PyTorch

accelerate==1.3.0
gradio==5.12.0
torch==2.6.0+xpu
torchao==0.8.0
transformers==4.48.1

IPEX

accelerate==1.3.0
gradio==5.12.0
torch==2.5.1+cxx11.abi
torchao==0.8.0
transformers==4.48.1
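
Before running the script it is worth confirming that PyTorch actually sees the Arc GPU. A minimal sketch, assuming the XPU builds listed above (in the IPEX environment, import intel_extension_for_pytorch first so the XPU backend is registered):

import torch
# In the IPEX environment, uncomment the next line before the checks:
# import intel_extension_for_pytorch as ipex

print(torch.xpu.is_available())      # True if the Arc GPU is visible
print(torch.xpu.get_device_name(0))  # e.g. "Intel(R) Arc(TM) A770 Graphics"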

Python script

The script quantizes the weights to 8 bits with TorchAO. 4-bit quantization did not work on this setup; a reference sketch of that variant follows.
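
For reference, the 4-bit attempt presumably only differed in the quantization type string passed to TorchAoConfig. This is my assumption of the failing configuration, not a verified reproduction:

# Hypothetical 4-bit config; this kind of setup did NOT work on the Arc A770 here.
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)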

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
from threading import Thread
import torch

# Local model directory; the same model is published on the Hugging Face Hub
# as "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3".
model_name = "Llama-3.1-Swallow-8B-Instruct-v0.3"

# System prompt (Japanese for "You are a sincere and excellent Japanese assistant.").
system_prompt_text = "あなたは誠実で優秀な日本人のアシスタントです。"
init = {
    "role": "system",
    "metadata": {"title": None},
    "content": system_prompt_text,
}

# Weight-only 8-bit quantization via TorchAO; device_map="auto" places the
# model on the XPU when one is available.
quantization_config = TorchAoConfig("int8_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# The streamer is shared by all requests: fine for a single local user,
# but not safe for concurrent sessions.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)

def user(
    message: str,
    history: list[dict]
):
    # Prepend the system message on the first turn, then append the user turn.
    if len(history) == 0:
        history.insert(0, init)
    history.append(
        {
            "role": "user",
            "metadata": {"title": None},
            "content": message
        }
    )
    return "", history

def bot(
    history: list[dict]
):
    input_tensors = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)

    input_ids = input_tensors["input_ids"]
    attention_mask = input_tensors["attention_mask"]

    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,  # ensure temperature/top_p actually take effect
        temperature=0.6,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Placeholder assistant message; the streaming loop below fills it in.
    history.append({"role": "assistant", "content": ""})

    # Run generation in a background thread so tokens can be streamed to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        history[-1]["content"] += new_text
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    # The button label is Japanese for "Start a new chat".
    clear = gr.ClearButton([msg, chatbot], value="新しいチャットを開始")

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )

demo.launch()
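
demo.launch() serves the UI at http://127.0.0.1:7860 by default. To reach the demo from another machine on the LAN, a common variant (using Gradio's standard server_name/server_port parameters) is:

# Optional: bind to all interfaces instead of localhost only.
demo.launch(server_name="0.0.0.0", server_port=7860)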