# Reference: https://github.com/arcee-ai/mergekit/blob/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/scripts/extract_lora.py
import argparse

import torch
from safetensors.torch import load_file, save_file

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.preferred_linalg_library("cusolver")

def get_low_rank_weight(x: torch.Tensor, rank: int, distribute_scale: bool, method: str, dtype: torch.dtype = torch.float32):
    def get_scale(s):
        if distribute_scale:
            sqrt_s = torch.sqrt(s)
            scale_a = torch.diag(sqrt_s)
            scale_b = torch.diag(sqrt_s)
        else:
            scale_a = torch.diag(s)
            scale_b = torch.eye(rank, dtype=torch.float32, device=s.device)
        return scale_a, scale_b

    if method == "svd":
        U, S, Vh = torch.linalg.svd(x, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
    elif method == "qr_svd":
        Q, R = torch.linalg.qr(x)
        U, S, Vh = torch.linalg.svd(R, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = Q @ U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
| elif method == "randomized_svd": | |
| rand_matrix = torch.randn(x.shape[1], rank, device=x.device, dtype=x.dtype) | |
| Y = x @ rand_matrix | |
| Q, _ = torch.linalg.qr(Y) | |
| B = Q.T @ x | |
| U_tilde, S, Vh = torch.linalg.svd(B, full_matrices=False) | |
| rank = min(rank, S.shape[0]) | |
| U = Q @ U_tilde | |
| U = U[:, :rank] | |
| S = S[:rank] | |
| Vh = Vh[:rank, :] | |
| lora_A = scale_b @ Vh | |
| lora_B = U @ scale_a | |
| elif method == "cur": | |
| col_norms = torch.norm(x, dim=0) | |
| row_norms = torch.norm(x, dim=1) | |
| col_probs = col_norms / col_norms.sum() | |
| row_probs = row_norms / row_norms.sum() | |
| col_indices = torch.multinomial(col_probs, rank, replacement=False) | |
| row_indices = torch.multinomial(row_probs, rank, replacement=False) | |
| C = x[:, col_indices] | |
| R = x[row_indices, :] | |
| U = torch.linalg.pinv(C[row_indices, :]) @ x[row_indices, col_indices] @ torch.linalg.pinv(R[:, col_indices]) | |
| lora_A = C @ U | |
| lora_B = R | |
    return (
        lora_A.contiguous().to(dtype=dtype),
        lora_B.contiguous().to(dtype=dtype),
    )
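
# Note (illustrative, not part of the original gist): every branch follows the convention that the
# returned factors reconstruct the input as lora_B @ lora_A ≈ x, with lora_A of shape (rank, in_features)
# and lora_B of shape (out_features, rank). A quick sanity check on a weight difference `diff`
# (hypothetical variable name) might look like:
#
#     A, B = get_low_rank_weight(diff, rank=256, distribute_scale=False, method="svd")
#     rel_err = torch.linalg.norm(B.float() @ A.float() - diff) / torch.linalg.norm(diff)
#
# The relative error depends on how close to low-rank the weight difference actually is.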

def main(args):
    low_rank_fn = torch.compile(get_low_rank_weight, mode="max-autotune-no-cudagraphs", dynamic=True)

    model1_state_dict = load_file(args.model1_path)
    model1_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model1_state_dict, 19, 38, 3072, 4)
    model2_state_dict = load_file(args.model2_path)
    model2_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model2_state_dict, 19, 38, 3072, 4)

    lora_state_dict = {}
    for key in model2_state_dict:
        # Only decompose projection weights; skip biases, norm scales, and keys absent from model1.
        if (
            (not key.endswith(".weight")) or
            ("norm" in key) or
            (key not in model1_state_dict)
        ):
            continue
        diff = model2_state_dict[key].float() - model1_state_dict[key].float()
        print(f"Processing key: {key} {diff.shape}")
        diff = diff.cuda()
        lora_A, lora_B = low_rank_fn(diff, args.rank, args.distribute_scale, args.method, dtype=model2_state_dict[key].dtype)
        A_key = "transformer." + key.removesuffix(".weight") + ".lora_A.weight"
        B_key = "transformer." + key.removesuffix(".weight") + ".lora_B.weight"
        lora_state_dict[A_key] = lora_A.cpu()
        lora_state_dict[B_key] = lora_B.cpu()

    save_file(lora_state_dict, args.output_path)

def get_args():
    parser = argparse.ArgumentParser("Extract LoRA from difference between model2 and model1 weights")
    parser.add_argument("--model1_path", type=str, required=True)
    parser.add_argument("--model2_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--rank", type=int, default=256)
    parser.add_argument("--distribute_scale", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        choices=["svd", "qr_svd", "randomized_svd", "cur"],
        default="svd",
    )
    return parser.parse_args()

# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split
# into (shift, scale), while in diffusers it is split into (scale, shift). Here we swap the linear
# projection weights so the diffusers implementation can be used.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight

def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_in.out_layer.bias"
    )

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "vector_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "vector_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "vector_in.out_layer.bias"
    )

    # guidance
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
            "guidance_in.out_layer.bias"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )
    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.weight"
        )
        converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.bias"
        )
| converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") | |
| converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") | |
| converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( | |
| original_state_dict.pop("final_layer.adaLN_modulation.1.weight") | |
| ) | |
| converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( | |
| original_state_dict.pop("final_layer.adaLN_modulation.1.bias") | |
| ) | |
| return converted_state_dict | |

if __name__ == "__main__":
    args = get_args()
    main(args)
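
# Example invocation (a sketch; the script name and .safetensors paths below are placeholders,
# not files provided with this gist):
#
#     python extract_lora_flux.py \
#         --model1_path flux-base.safetensors \
#         --model2_path flux-finetune.safetensors \
#         --output_path extracted_lora.safetensors \
#         --rank 256 \
#         --method svd \
#         --distribute_scale
#
# model1 is treated as the base checkpoint and model2 as the fine-tuned one; the saved LoRA
# approximates model2 - model1, with keys prefixed by "transformer." in the diffusers naming.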