How to Load a Custom Checkpoint into Stable Diffusion in Python

Have you ever found yourself stuck trying to load a custom checkpoint into Stable Diffusion, unable to find the right solution anywhere? That happened to me recently. After a lot of trial and error, I finally figured it out, and I thought I'd share my solution so others have an easier time. The approach is aimed at .safetensors checkpoints like the ones you can download from Civitai.

Why This Article?

I scoured the internet and various forums but just couldn’t get the full picture of how to load custom checkpoints into Stable Diffusion. That’s why I decided to document my findings and share them here. If you’re working on incorporating custom models into Stable Diffusion, you’ll find this guide quite handy!

The Code

Below is the code that I used to successfully load a custom checkpoint into the Stable Diffusion pipeline. I’ll break down what each part does so you can follow along easily.


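import torch
from diffusers import StableDiffusionXLPipeline
from safetensors import safe_open

model = "sdxl"   # which base pipeline to build; must match the checkpoint's architecture
device = "cuda"  # the fp16 pipeline and safe_open(device=0) below assume a CUDA GPU

if model == "sdxl":
    # Start from the official SDXL base pipeline so every component already exists
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16"
    ).to(device)

    # Read every tensor in the custom checkpoint into a dict keyed by tensor name
    tensors = {}
    with safe_open("dreamshaper_8.safetensors", framework="pt", device=0) as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)

    # Helper function to convert safetensors keys to match model's expected state_dict keys.
    # Keys missing from the checkpoint silently keep their base-model values.
    def convert_keys(model_dict, tensors, prefix):
        new_dict = {}
        for key in model_dict.keys():
            tensor_key = prefix + key
            if tensor_key in tensors:
                new_dict[key] = tensors[tensor_key]
            else:
                new_dict[key] = model_dict[key]
        return new_dict

    # Function to load weights into the model
    def load_weights(model, prefix):
        model_dict = model.state_dict()
        converted_dict = convert_keys(model_dict, tensors, prefix)
        model.load_state_dict(converted_dict)

    # Update the pipeline's models with the loaded weights.
    # These prefixes assume the checkpoint names its tensors "text_encoder.*", "vae.*" and "unet.*".
    load_weights(pipe.text_encoder, "text_encoder.")
    load_weights(pipe.vae, "vae.")
    load_weights(pipe.unet, "unet.")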

    print("Loaded weights from the custom checkpoint")
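Because the loader depends on the checkpoint's tensor names starting with text_encoder., vae. and unet., it's worth checking which prefixes a file actually uses before loading it. Here's a minimal sketch for that; summarize_prefixes and checkpoint_keys are hypothetical helper names, not part of any library. It groups tensor names by their first dot-separated segment and only reads the file header through safetensors' safe_open, so no weights are loaded.

```python
from collections import Counter


def summarize_prefixes(keys, depth=1):
    """Count tensor names by their first `depth` dot-separated segments."""
    counts = Counter()
    for key in keys:
        prefix = ".".join(key.split(".")[:depth])
        counts[prefix] += 1
    return counts


def checkpoint_keys(path):
    """Return the tensor names stored in a .safetensors file (reads the header only)."""
    # Imported here so summarize_prefixes works even without safetensors installed
    from safetensors import safe_open

    with safe_open(path, framework="pt", device="cpu") as f:
        return list(f.keys())


if __name__ == "__main__":
    # Hypothetical key names in the style the loading code above expects
    demo_keys = [
        "unet.conv_in.weight",
        "unet.conv_in.bias",
        "vae.encoder.conv_in.weight",
        "text_encoder.text_model.embeddings.position_ids",
    ]
    print(summarize_prefixes(demo_keys))
    # For a real file:
    # print(summarize_prefixes(checkpoint_keys("dreamshaper_8.safetensors")))
```

If the counts land under text_encoder, vae and unet, the prefixes in load_weights will match. If they land under names such as model, first_stage_model, cond_stage_model or conditioner instead (the original Stable Diffusion layout), load_weights will find nothing to replace.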

Breaking Down the Code

  1. Initialize the Pipeline:

    if model == "sdxl":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16"
        ).to(device)
    

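    Here, we initialize the Stable Diffusion XL pipeline from the official SDXL base weights. This gives us a complete pipeline (text encoders, VAE, UNet and scheduler) whose weights we then overwrite with the custom checkpoint. Passing torch_dtype=torch.float16 together with variant="fp16" downloads the half-precision weight files, roughly halving memory use compared with float32; adjust both to suit your hardware.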

  2. Load the Custom Checkpoint:

    tensors = {}
    with safe_open("dreamshaper_8.safetensors", framework="pt", device=0) as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    

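    We then read every tensor from the custom checkpoint file (dreamshaper_8.safetensors) into a dictionary keyed by tensor name; device=0 places each tensor on the first CUDA GPU. Checkpoints like this one are available from sites such as Civitai. Make sure the checkpoint matches the pipeline from step 1: DreamShaper 8 is built on Stable Diffusion 1.5, so it belongs with StableDiffusionPipeline, while the SDXL pipeline needs an SDXL-based checkpoint.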

  3. Convert Keys:

    def convert_keys(model_dict, tensors, prefix):
        new_dict = {}
        for key in model_dict.keys():
            tensor_key = prefix + key
            if tensor_key in tensors:
                new_dict[key] = tensors[tensor_key]
            else:
                new_dict[key] = model_dict[key]
        return new_dict
    

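    The convert_keys function builds a new state dictionary using the model's own key names. For each key it looks up prefix + key in the checkpoint (for example, unet.conv_in.weight for the UNet key conv_in.weight) and uses that tensor if it exists. Any key the checkpoint doesn't contain keeps the model's current weight, so a wrong prefix fails silently instead of raising an error.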

  4. Load Weights:

    def load_weights(model, prefix):
        model_dict = model.state_dict()
        converted_dict = convert_keys(model_dict, tensors, prefix)
        model.load_state_dict(converted_dict)
    

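    The load_weights function takes one pipeline component, builds the merged state dictionary with convert_keys, and copies it into the component with load_state_dict. Because convert_keys returns an entry for every key the model expects, load_state_dict's strict key check always passes, even when nothing from the checkpoint matched.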

  5. Update the Pipeline’s Models:

    load_weights(pipe.text_encoder, "text_encoder.")
    load_weights(pipe.vae, "vae.")
    load_weights(pipe.unet, "unet.")
    

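    Finally, we use our load_weights function to update the text encoder, VAE, and UNet components of the pipeline with the custom weights. This only works if the checkpoint actually stores its tensors under the text_encoder., vae. and unet. prefixes. Many single-file checkpoints from Civitai instead use the original Stable Diffusion layout (for example model.diffusion_model.* for the UNet and first_stage_model.* for the VAE); for those, diffusers' from_single_file loader (StableDiffusionXLPipeline.from_single_file, or StableDiffusionPipeline.from_single_file for SD 1.x models) converts the keys for you. Also note that SDXL pipelines have a second text encoder, pipe.text_encoder_2, which this code leaves at the base-model weights.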

  6. Confirmation:

    print("Loaded weights from the custom checkpoint")
    

    A simple print statement marking the end of the loading step. It runs whether or not any keys actually matched, so on its own it doesn't prove the custom weights were applied.
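Since the confirmation message prints either way, a quick way to see what convert_keys really does is to run it on toy data. The sketch below copies convert_keys from above unchanged and feeds it plain Python dicts, with strings standing in for tensors and made-up key names. count_matches is a hypothetical helper, not part of diffusers, that reports how many of a component's keys the checkpoint provides under a given prefix.

```python
def convert_keys(model_dict, tensors, prefix):
    # Same logic as the loader above: look up prefix + key, else keep the old value
    new_dict = {}
    for key in model_dict.keys():
        tensor_key = prefix + key
        if tensor_key in tensors:
            new_dict[key] = tensors[tensor_key]
        else:
            new_dict[key] = model_dict[key]
    return new_dict


def count_matches(model_dict, tensors, prefix):
    """How many of the model's keys are present in the checkpoint under `prefix`."""
    return sum(1 for key in model_dict if prefix + key in tensors)


# Toy stand-ins: strings instead of tensors, invented key names
model_dict = {"conv_in.weight": "base_w", "conv_in.bias": "base_b"}
checkpoint = {"unet.conv_in.weight": "custom_w"}

merged = convert_keys(model_dict, checkpoint, "unet.")
print(merged)  # conv_in.weight comes from the checkpoint, conv_in.bias keeps the base value
print(count_matches(model_dict, checkpoint, "unet."), "of", len(model_dict))

# A wrong prefix raises no error: every value silently falls back to the base model
print(convert_keys(model_dict, checkpoint, "model.diffusion_model.") == model_dict)
print(count_matches(model_dict, checkpoint, "model.diffusion_model."))
```

In the real loader, printing count_matches(model.state_dict(), tensors, prefix) inside load_weights would show right away whether each component picked up any custom weights; a count of zero means the prefix doesn't match the checkpoint.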

Conclusion

I hope you find this guide helpful. I know how frustrating it can be to search for a solution and come up empty-handed. By following these steps, you should be able to load a custom checkpoint into Stable Diffusion, as long as the checkpoint matches the base pipeline and its tensor names match the prefixes used in load_weights.

Feel free to leave any questions or feedback in the comments below. If there’s enough interest, I might explore more advanced topics in future articles. Happy diffusing!
