Solver for the smart_cryptooo challenge from the DEF CON CTF qualifiers.
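# Overview (inferred from the code below, not from an official write-up):
# the ciphertext is a stream of 64-double blocks (one double per plaintext
# bit), grouped into bunches of 17 blocks: 16 encrypted message blocks
# followed by one encrypted key block holding the next bunch's key.
# philosophy.html is public, so the message blocks have known plaintext;
# we train a decoder network (with one learned key embedding per bunch)
# to invert the encryption, then chain the key blocks to decrypt past the
# known-plaintext region.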
import binascii
import itertools
import struct

import numpy as np
import torch
import torch.nn as nn
# configurable stuff
bits = 64
bunch_size = 16
m_bits = bits
m_bytes = bits // 8
k_bits = bits
c_bits = bits
pad = "same"
def decode_message(a, threshold=0):
    """
    Decodes ML output into bytes.
    """
    i = int("".join("0" if b < threshold else "1" for b in a), 2)
    return binascii.unhexlify(hex(i)[2:].rjust(m_bytes * 2, "0"))


def encode_message(m):
    """
    Encodes bytes into ML inputs.
    """
    n = int(binascii.hexlify(m).ljust(m_bytes * 2, b"0"), 16)
    encoded = [(-1 if b == "0" else 1) for b in bin(n)[2:].rjust(m_bits, "0")]
    assert decode_message(encoded) == m
    return encoded
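# Quick round-trip sanity check of the codec above (an illustrative addition;
# b"ABCDEFGH" is an arbitrary 8-byte example, not challenge data): bits are
# encoded as -1/+1 and decoded by thresholding at 0, so the trip is lossless.
assert decode_message(encode_message(b"ABCDEFGH")) == b"ABCDEFGH"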
# prepare data
with open("./https___oooverflow.io_philosophy.html-secret.enc", "rb") as f:
    enc = f.read()
with open("./philosophy.html", "rb") as f:
    html = f.read()

assert len(enc) % 512 == 0
# each 512-byte chunk is one block of 64 doubles (one double per bit)
enc_float = [
    struct.unpack("64d", enc[i : i + 512]) for i in range(0, len(enc), 512)
]  # (3274, 64)
# group into bunches of 17 blocks: 16 message blocks + 1 key block
enc_float = [
    enc_float[i : i + 17] for i in range(0, len(enc_float), 17)
]  # (193, 17, 64); only the last bunch is (10, 64), not (17, 64)
enc_msg = sum([e[:16] for e in enc_float], [])  # (3082, 64)
enc_msg = np.array(enc_msg, dtype=np.float32)
# the 17th block of each full bunch, repeated 16x to align with enc_msg
enc_key = sum([[e[16]] * 16 for e in enc_float if len(e) == 17], [])
enc_key = np.array(enc_key, dtype=np.float32)
# known plaintext, space-padded to 8-byte blocks
msg = [
    encode_message(html[i : i + m_bytes].ljust(m_bytes))
    for i in range(0, len(html), m_bytes)
]
msg = np.array(msg, dtype=np.float32)
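# Decoder network. The exact layer stack is a modeling choice, not something
# dictated by the challenge: one learned 64-dim key embedding per bunch, plus
# a linear + Conv1d stack mapping a (ciphertext, key) pair to 64 plaintext
# bits, in [0, 1] (zero_one=True) or [-1, 1] (zero_one=False).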
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.keys = nn.Embedding(200, 64)  # one learned key per bunch
        self.linear0 = nn.Linear(m_bits + k_bits, m_bits + k_bits)
        self.conv0 = nn.Conv1d(1, 2, 4, 1, padding=1, padding_mode="replicate")
        self.conv1 = nn.Conv1d(2, 4, 2, 2, padding=1, padding_mode="replicate")
        self.conv2 = nn.Conv1d(4, 4, 1, 1, padding_mode="replicate")
        self.conv3 = nn.Conv1d(4, 1, 1, 1, padding_mode="replicate")

    def forward(self, alice_plain, alice_key, zero_one):
        # (batch, 64) cipher + (batch, 64) key -> (batch, 128) -> (batch, 64)
        x = torch.cat([alice_plain, alice_key], dim=1)
        h = torch.sigmoid(self.linear0(x))
        h = h.reshape(h.shape[0], 1, h.shape[1])
        h = torch.sigmoid(self.conv0(h))
        h = torch.sigmoid(self.conv1(h))
        h = torch.sigmoid(self.conv2(h))
        h = torch.sigmoid(self.conv3(h))
        h = h.squeeze(1)
        if zero_one:
            return h  # bit probabilities in [0, 1]
        else:
            return 2 * h - 1  # -1/+1 encoding, matching encode_message

    def get_keys(self, alice_key_idx):
        return self.keys(alice_key_idx)
device = torch.device("cuda:0")
bob_model = Model().to(device)
abe_optim = torch.optim.AdamW([{"params": bob_model.parameters()}], lr=5e-4)
msg_tensor = torch.from_numpy(msg).type(torch.float32).to(device)
enc_msg_tensor = torch.from_numpy(enc_msg).type(torch.float32).to(device)
enc_key_tensor = torch.from_numpy(enc_key).type(torch.float32).to(device)
batch_size = 4096  # unused; every step trains on the full known-plaintext set
best_loss = 1e5

### train (stop with Ctrl+C once the loss is relatively small (< 1.0)) ###
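# Two training signals (my reading of the scheme, inferred from the code):
#   dec_loss - BCE between the decoder's output on the first n bunches and
#              the known plaintext of philosophy.html.
#   key_loss - key-chaining consistency: decrypting bunch i's key block with
#              key i should reproduce key i+1, and key 0 is pulled toward
#              -1/+1 bits so the embeddings stay bit-like.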
n = 101  # train on the first n bunches (covered by the known plaintext)
try:
    print("=" * 80)
    print(n)
    for step in itertools.count():
        bob_model.train()
        used_idx = np.arange(16 * n)
        alice_plain = msg_tensor[used_idx].detach()
        alice_key = bob_model.get_keys(torch.from_numpy(used_idx // 16).to(device))
        bob_cipher = enc_msg_tensor[used_idx].detach()
        bob_plain = bob_model(bob_cipher, alice_key, zero_one=True)
        dec_loss = nn.BCELoss(reduction="sum")(bob_plain, (alice_plain + 1.0) / 2)
        if step % 1 == 0:  # always true; raise the modulus to apply key_loss less often
            # decrypt every bunch's key block with that bunch's own key
            key_dec = bob_model(
                enc_key_tensor[: 16 * n : 16],
                bob_model.get_keys(torch.arange(n).to(device)),
                zero_one=False,
            )
            # threshold the decoded keys to -1/+1 bits
            key_dec = torch.where(
                key_dec > 0.0,
                torch.ones(key_dec.shape).to(device),
                torch.ones(key_dec.shape).to(device) * -1,
            )
            # the decoded key of bunch i should equal the embedding of key i+1
            key_loss = torch.sum(
                torch.abs(
                    key_dec[:-1].detach()
                    - bob_model.get_keys(
                        torch.arange(1, 1 + len(key_dec) - 1).to(device)
                    )
                )
            )
            # pull each bit of key 0 toward whichever of -1/+1 is closer
            key_loss += torch.sum(
                torch.min(
                    torch.abs(
                        bob_model.get_keys(torch.arange(1).to(device))
                        - torch.ones((1, 64)).to(device)
                    ),
                    torch.abs(
                        bob_model.get_keys(torch.arange(1).to(device))
                        - torch.ones((1, 64)).to(device) * -1
                    ),
                )
            )
        else:
            key_loss = dec_loss.new_zeros(())
        loss = dec_loss + key_loss
        abe_optim.zero_grad()
        loss.backward()
        abe_optim.step()
        if step % 100 == 0:
            if loss.item() < best_loss:
                torch.save(bob_model, "./model.pt")
                best_loss = loss.item()
            print(step, loss.item(), dec_loss.item(), key_loss.item())
        if step % 1000 == 0:
            # check whether the html page can be reconstructed
            out = bob_model(
                enc_msg_tensor[: len(enc_msg_tensor) // 16 * 16],
                bob_model.get_keys(
                    torch.repeat_interleave(
                        torch.arange(len(enc_msg_tensor) // 16), 16
                    ).to(device)
                ),
                zero_one=False,
            )
            out_msg_first = b"".join(decode_message(m) for m in out[: 16 * 32])
            print(out_msg_first[:512])
except KeyboardInterrupt:
    print("end")
### decrypt ###
bob_model = torch.load("./model.pt")
n = 100  # last bunch whose key embedding was trained
m = 10  # how many bunches past it to decrypt
key_pred = bob_model.get_keys(torch.arange(n, n + 1).to(device))
for n in range(n, n + m):
    # predict the next key from this bunch's key block, then decrypt the
    # following bunch's 16 message blocks with it
    key_pred = bob_model(enc_key_tensor[16 * n : 16 * n + 1], key_pred, zero_one=False)
    out = (
        bob_model(
            enc_msg_tensor[16 * (n + 1) : 16 * (n + 2)],
            torch.repeat_interleave(key_pred, 16, dim=0),
            zero_one=False,
        )
        .detach()
        .cpu()
        .numpy()
    )
    print(n, b"".join([decode_message(o) for o in out]))