Loading Transformers Models

Categories: guide, LLM, HuggingFace

Author: im@johnho.ca

Published: Tuesday, August 5, 2025

Abstract: Lessons learnt on loading models with the transformers library in the most memory-efficient manner, and a GitHub template you can use to quickly create HuggingFace Spaces.

Introduction

As mentioned in a previous post, the transformers library is a powerful tool for working with a host of AI models, either locally or by “deploying” your own Space.

To make the creation of HuggingFace Spaces easier, I have created a GitHub Template with an action set up to handle CI/CD to the Space’s repository and to create the requirements.txt file automatically using uv1

After a few attempts to debug a phantom OOM bug2 while building this Video Caption Demo using VLMs, here are a few lessons learnt on how to load models with the transformers library in the most memory-efficient manner.

Loading Transformers Models

Here’s a code snippet for an example HuggingFace Space Gradio app that illustrates best practices for loading transformers models efficiently. The details are highlighted in the code annotations on the relevant lines.

app.py
import spaces, torch, time
import gradio as gr
from transformers import (
    AutoModelForImageTextToText,
    Gemma3nForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)

# Flash Attention for ZeroGPU
import subprocess

subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Set target DEVICE and DTYPE
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
DEVICE = "auto"
print(f"Device: {DEVICE}, dtype: {DTYPE}")


def load_model(
    model_name: str = "chancharikm/qwen2.5-vl-7b-cam-motion-preview",
    use_flash_attention: bool = False,
    apply_quantization: bool = True,
):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Load model weights in 4-bit
        bnb_4bit_quant_type="nf4",  # Use NF4 quantization (or "fp4")
        bnb_4bit_compute_dtype=DTYPE,  # Perform computations in bfloat16/float16
        bnb_4bit_use_double_quant=True,  # Optional: further quantization for slightly more memory saving
    )

    # Determine model family from model name
    model_family = model_name.split("/")[-1].split("-")[0]

    # Common model loading arguments
    common_args = {
        "torch_dtype": DTYPE,
        "device_map": DEVICE,
        "low_cpu_mem_usage": True,
        "quantization_config": bnb_config if apply_quantization else None,
    }
    if use_flash_attention:
        common_args["attn_implementation"] = "flash_attention_2"

    # Load model based on family
    match model_family:
        # case "qwen2.5" | "Qwen2.5":
        #     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        #         model_name, **common_args
        #     )
        case "gemma":
            model = Gemma3nForConditionalGeneration.from_pretrained(
                model_name, **common_args
            )
        case "InternVL3":
            model = AutoModelForImageTextToText.from_pretrained(
                model_name, **common_args
            )
        case _:
            raise ValueError(f"Unsupported model family: {model_family}")

    # Set model to evaluation mode for inference (disables dropout, etc.)
    return model.eval()


def load_processor(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
    return AutoProcessor.from_pretrained(
        model_name,
        device_map=DEVICE,
        use_fast=True,
        torch_dtype=DTYPE,
    )


print("Loading Models and Processors...")
MODEL_ZOO = {
    "qwen2.5-vl-7b-instruct": load_model(
        model_name="Qwen/Qwen2.5-VL-7B-Instruct",
        use_flash_attention=False,
        apply_quantization=False,
    ),
    "InternVL3-1B-hf": load_model(
        model_name="OpenGVLab/InternVL3-1B-hf",
        use_flash_attention=False,
        apply_quantization=False,
    ),
    "InternVL3-2B-hf": load_model(
        model_name="OpenGVLab/InternVL3-2B-hf",
        use_flash_attention=False,
        apply_quantization=False,
    ),
    "InternVL3-8B-hf": load_model(
        model_name="OpenGVLab/InternVL3-8B-hf",
        use_flash_attention=False,
        apply_quantization=True,
    ),
}

PROCESSORS = {
    "qwen2.5-vl-7b-instruct": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"),
    "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
    "InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"),
    "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"),
}
print("Models and Processors Loaded!")


# Our Inference Function
@spaces.GPU(duration=120)
def video_inference(
    video_path: str,
    prompt: str,
    model_name: str,
    fps: int = 8,
    max_tokens: int = 512,
    temperature: float = 0.1,
):
    s_time = time.time()
    model = MODEL_ZOO[model_name]
    processor = PROCESSORS[model_name]
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": video_path,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    with torch.no_grad():
        model_family = model_name.split("-")[0]
        match model_family:
            case "InternVL3":
                inputs = processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                    fps=fps,
                    # num_frames = 8
                ).to("cuda", dtype=DTYPE)

                output = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=float(temperature),
                    do_sample=temperature > 0.0,
                )
                output_text = processor.decode(
                    output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
                )
            case _:
                raise ValueError(f"{model_name} is not currently supported")
    return {
        "output_text": output_text,
        "fps": fps,
        "inference_time": time.time() - s_time,
    }


# the Gradio App
app = gr.Interface(
    fn=video_inference,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(
            label="Prompt",
            lines=3,
            info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained specific prompts",
            value="Describe the camera motion in this video.",
        ),
        gr.Dropdown(label="Model", choices=list(MODEL_ZOO.keys())),
        gr.Number(
            label="FPS",
            info="inference sampling rate (Qwen2.5VL is trained on videos with 8 fps); a value of 0 means the FPS of the input video will be used",
            value=8,
            minimum=0,
            step=1,
        ),
        gr.Slider(
            label="Max Tokens",
            info="maximum number of tokens to generate",
            value=128,
            minimum=32,
            maximum=512,
            step=32,
        ),
        gr.Slider(
            label="Temperature",
            value=0.0,
            minimum=0.0,
            maximum=1.0,
            step=0.1,
        ),
    ],
    outputs=gr.JSON(label="Output JSON"),
    title="Video Chat with VLM",
    description='comparing various "small" VLMs on the task of video captioning',
    api_name="video_inference",
)
app.launch(
    mcp_server=True
)
1. you’ll need spaces for access to ZeroGPU, torch for setting the device and data type for your models, and time to understand how long each inference takes.
2. you’ll need gradio for creating the Gradio interface when deploying to a HuggingFace Space.
3. bitsandbytes quantizes models “on the fly”, but it requires a CUDA-enabled GPU3 (sketch below).
4. installing flash-attn for ZeroGPU requires special handling
5. Using torch.float16 (half-precision) or torch.bfloat16 reduces memory usage for model weights and activations by half compared to torch.float32. bfloat16 is generally preferred for training stability due to its wider dynamic range, but float16 is often sufficient for inference and widely supported (sketch below).
6. device_map="auto" (requires accelerate) will try to fit the model layers across available GPUs as much as possible and then offload the rest to CPU (sketch below).
7. This flag (which also requires accelerate) tells transformers to load the model directly to the target device, or to stream it in a more memory-efficient way, avoiding a large CPU RAM spike.
8. Flash Attention is an optimized attention algorithm designed to address the memory and computational bottlenecks of the standard attention mechanism in Transformers. A CUDA-enabled GPU is required (sketch below). Read this lecture note for a deep dive.
9. Always set your model to evaluation mode (model.eval()) for inference. This disables layers like Dropout and BatchNorm, which behave differently during inference, and can sometimes free up a small amount of memory.
10. Models should be loaded well before inference to avoid inflating inference time.
11. This is the max duration (in seconds) that the inference function can take before timing out. The user must also have at least this much ZeroGPU time remaining before calling the function.
12. Wrap your inference calls in the torch.no_grad() context manager. This prevents PyTorch from building the computation graph for gradients, saving a significant amount of memory on intermediate activations (sketch below).
13. this turns on the creation of an MCP server (in addition to the API), but it requires a detailed docstring for the inference function, which we don’t have here (sketch below). See the official doc for more details on Gradio’s MCP Server.
14. with some models like Qwen2.5VL, the fps value will determine the frame extraction rate.
15. temperature is a parameter that controls the randomness of the generated output. It influences how “creative” or “deterministic” the model’s responses will be.
16. This is crucial: temperature only has an effect if sampling is enabled. If do_sample=False (the transformers default, and what this app passes when temperature is 0), the model performs greedy decoding, always picking the most probable next token, regardless of the temperature value.
17. when you are trying to load a gated model, like Gemma-3n-E4B-it, the transformers library automatically checks for an HF_TOKEN environment variable. If set, it will use that token for authentication without needing login() or a token argument in .from_pretrained() (sketch below).
18. For InternVL3, the num_frames parameter is also available, but it is mutually exclusive with fps (sketch below).
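
A few sketches follow, expanding on some of the annotations above; none of them are part of the original Space, so treat them as illustrative. First, on annotation 3: to see how much memory 4-bit quantization actually saves, you can compare the reported footprints of a quantized and an unquantized load. get_memory_footprint() is a standard transformers helper, and the checkpoint is just the smallest InternVL3 model from the MODEL_ZOO.

import torch
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig

# Rough sanity check of 4-bit savings (assumes a CUDA GPU and bitsandbytes installed).
name = "OpenGVLab/InternVL3-1B-hf"

fp16_model = AutoModelForImageTextToText.from_pretrained(
    name, torch_dtype=torch.float16, device_map="auto"
)
print(f"fp16 : {fp16_model.get_memory_footprint() / 1e9:.2f} GB")
del fp16_model
torch.cuda.empty_cache()  # release the cached GPU memory before the 4-bit load

nf4_model = AutoModelForImageTextToText.from_pretrained(
    name,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
)
print(f"4-bit: {nf4_model.get_memory_footprint() / 1e9:.2f} GB")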
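
On annotation 5, a back-of-the-envelope calculation makes the dtype savings concrete: weight memory is roughly the parameter count times the bytes per parameter, so for a 7B-parameter model (weights only, ignoring activations and the KV cache):

params = 7e9  # e.g. a 7B-parameter model such as Qwen2.5-VL-7B
print(f"float32  : {params * 4 / 1e9:.0f} GB")    # ~28 GB
print(f"bf16/fp16: {params * 2 / 1e9:.0f} GB")    # ~14 GB
print(f"4-bit nf4: {params * 0.5 / 1e9:.1f} GB")  # ~3.5 GB, plus some quantization overhead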
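
On annotation 6, if you want to see where accelerate actually put each layer, models loaded with device_map="auto" carry an hf_device_map attribute; a minimal sketch, reusing a model from the MODEL_ZOO above:

# Maps module names to a GPU index, "cpu", or "disk", depending on placement.
model = MODEL_ZOO["InternVL3-1B-hf"]
for module_name, device in model.hf_device_map.items():
    print(f"{module_name} -> {device}")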
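
On annotation 8, since flash-attn wheels are compiled against a specific torch/CUDA combination, a defensive pattern (my suggestion, not what the Space does) is to fall back to PyTorch’s built-in SDPA attention when the import fails:

try:
    import flash_attn  # noqa: F401
    ATTN_IMPLEMENTATION = "flash_attention_2"
except ImportError:
    ATTN_IMPLEMENTATION = "sdpa"  # PyTorch's scaled-dot-product attention, no extra install needed

# then pass attn_implementation=ATTN_IMPLEMENTATION to from_pretrained(...)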
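
On annotation 12, torch.inference_mode() is a slightly stricter alternative to torch.no_grad() that also disables view tracking and version counters; either works for the inference block above:

# Same shape as the no_grad block in video_inference, just with inference_mode.
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=max_tokens)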
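
On annotation 13, Gradio’s MCP server reads the function’s docstring (and type hints) to describe the tool and its parameters. A hypothetical docstring for video_inference might look like this:

def video_inference(video_path: str, prompt: str, model_name: str,
                    fps: int = 8, max_tokens: int = 512, temperature: float = 0.1):
    """Caption or answer a question about a video using the selected VLM.

    Args:
        video_path: Path to the input video file.
        prompt: Instruction or question about the video.
        model_name: One of the keys in MODEL_ZOO.
        fps: Frame sampling rate; 0 uses the video's native FPS.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0 means greedy decoding.
    """
    ...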
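
On annotation 17, in practice that just means adding HF_TOKEN as a Space secret (or exporting it locally) before the app starts. A minimal sketch using huggingface_hub’s login helper, mainly to fail fast with a clear message:

import os
from huggingface_hub import login

# transformers/huggingface_hub pick HF_TOKEN up automatically; the explicit
# login() call only makes the failure mode obvious when the token is missing.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    print("HF_TOKEN not set; gated models (e.g. Gemma) will fail to download.")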
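
And on annotation 18, the InternVL3 branch of video_inference could sample a fixed number of frames instead of a fixed rate by swapping the fps argument for num_frames; a sketch mirroring the call above:

# Sample exactly 8 frames from the clip instead of sampling at a fixed fps.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    num_frames=8,  # mutually exclusive with fps
).to("cuda", dtype=DTYPE)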

Resources

In addition to the few tips and tricks shared above, here are a few more resources worth checking out:

Footnotes

  1. This is optional, and in some cases it might be better to create the requirements.txt file manually, especially when working with ZeroGPU, since it would be non-trivial to reproduce the Space’s hardware setup. However, note that some Spaces, like this one for gemma-3n-E4B-it built by the HuggingFace team, do use uv.↩︎

  2. As of the writing of this post, ZeroGPU uses an H200 GPU, which has about 141GB of VRAM, so you’re unlikely to encounter any OOM issues. My “bug” was actually due to version pinning of the torch, torchvision, and transformers libraries, which I realized after looking at another Space’s requirements.txt file. In the end I was able to load up 7 VLMs, including one with 8 billion parameters. But the “real” bug is not being able to use flash attention with the Qwen2.5VL model.↩︎

  3. Without access to CUDA? You could actually use this Space by the HuggingFace team to create your own quantized model using their hardware and the bitsandbytes library.↩︎
