Skip to content

Instantly share code, notes, and snippets.

@y011d4
Last active September 17, 2021 03:44
Show Gist options
  • Select an option

  • Save y011d4/eb16b8a719d744495929577bb825ece7 to your computer and use it in GitHub Desktop.

Select an option

Save y011d4/eb16b8a719d744495929577bb825ece7 to your computer and use it in GitHub Desktop.
Solver for the smart_cryptooo challenge in the DEF CON Quals
import binascii
import itertools
import struct
import numpy as np
import torch
import torch.nn as nn
# Tunable sizes shared by the codec helpers and the network.
bits = 64  # width of one block, in bits
bunch_size = 16  # not referenced by the visible code
# Message, key and ciphertext blocks all have the same width.
m_bits = k_bits = c_bits = bits
m_bytes = m_bits // 8  # width of one message block, in bytes
pad = "same"  # not referenced by the visible code
def decode_message(a, threshold=0):
    """
    Turn one row of ML outputs back into bytes.

    Every element of ``a`` that is below ``threshold`` is read as a 0 bit and
    every other element as a 1 bit, most significant bit first.  The result is
    left-padded with zero bytes up to ``m_bytes`` bytes.
    """
    digits = []
    for value in a:
        digits.append("0" if value < threshold else "1")
    as_int = int("".join(digits), 2)
    # Hex round-trip; zero-fill so short values still yield m_bytes bytes.
    hex_text = format(as_int, "x").zfill(m_bytes * 2)
    return binascii.unhexlify(hex_text)
def encode_message(m):
    """
    Encodes bytes into ML inputs.

    ``m`` is right-padded with zero bytes to ``m_bytes`` bytes and converted
    into a list of ``m_bits`` values, most significant bit first: ``-1`` for
    a 0 bit and ``1`` for a 1 bit.

    Raises ``AssertionError`` if decoding the result does not give back the
    (padded) message.
    """
    # Pad up front so that the round-trip check below compares like with
    # like; comparing against the unpadded input made every message shorter
    # than m_bytes fail even though it had been padded for encoding.
    padded = m.ljust(m_bytes, b"\x00")
    n = int.from_bytes(padded, "big")
    encoded = [(-1 if b == "0" else 1) for b in format(n, "b").zfill(m_bits)]
    # Internal sanity check: the encoding must survive decode_message.
    assert decode_message(encoded) == padded
    return encoded
# prepare data
with open("./https___oooverflow.io_philosophy.html-secret.enc", "rb") as f:
    enc = f.read()
with open("./philosophy.html", "rb") as f:
    html = f.read()

# The ciphertext file is read as a flat stream of native doubles, 64 per
# record, i.e. 512 bytes per record.
if len(enc) % 512 != 0:
    raise ValueError(
        f"ciphertext length {len(enc)} is not a multiple of 512 bytes"
    )
enc_float = list(struct.iter_unpack("64d", enc))  # (3274, 64)

# Regroup into runs of 17 records.
enc_float = [
    enc_float[i : i + 17] for i in range(0, len(enc_float), 17)
]  # (193, 17, 64); only the last group is (10, 64) instead of (17, 64)

# The first 16 records of every group are the encrypted message blocks.
enc_msg = [record for group in enc_float for record in group[:16]]  # (3082, 64)
enc_msg = np.array(enc_msg, dtype=np.float32)

# The 17th record of every complete group is kept as that group's encrypted
# key, repeated 16 times so that row i lines up with row i of enc_msg.
enc_key = [
    group[16] for group in enc_float if len(group) == 17 for _ in range(16)
]
enc_key = np.array(enc_key, dtype=np.float32)

# Plaintext blocks of the HTML page; the final short block is space-padded
# (bytes.ljust default fill) to m_bytes before encoding.
msg = [
    encode_message(html[i : i + m_bytes].ljust(m_bytes))
    for i in range(0, len(html), m_bytes)
]
msg = np.array(msg, dtype=np.float32)
class Model(nn.Module):
    """
    Network mapping a (block, key) pair to an ``m_bits``-wide output.

    The two inputs are concatenated, passed through one fully connected layer
    and four 1-D convolutions, each followed by a sigmoid.  The module also
    owns an embedding table of 200 learnable key vectors.
    """

    def __init__(self):
        super().__init__()
        # Table of learnable key vectors; the width must equal k_bits because
        # a looked-up key is concatenated with a block and fed to linear0.
        self.keys = nn.Embedding(200, k_bits)
        self.linear0 = nn.Linear(m_bits + k_bits, m_bits + k_bits)
        # With the default 64 + 64 input the length goes 128 -> 127 (conv0)
        # -> 64 (conv1) and then stays 64 through the two 1x1 convolutions.
        self.conv0 = nn.Conv1d(1, 2, 4, 1, padding=1, padding_mode="replicate")
        self.conv1 = nn.Conv1d(2, 4, 2, 2, padding=1, padding_mode="replicate")
        self.conv2 = nn.Conv1d(4, 4, 1, 1, padding_mode="replicate")
        self.conv3 = nn.Conv1d(4, 1, 1, 1, padding_mode="replicate")

    # Defined as forward() rather than by overriding __call__, so that
    # nn.Module.__call__ (and with it any registered hooks) keeps working;
    # model(x, k, zero_one=...) behaves exactly as before.
    def forward(self, alice_plain, alice_key, zero_one):
        """
        Run the network on a batch.

        Args:
            alice_plain: 2-D tensor of blocks, one row per sample.  The
                training code in this file passes ciphertext rows here.
            alice_key: 2-D tensor of key vectors, one row per sample.
            zero_one: if true return the sigmoid output in [0, 1];
                otherwise rescale it to [-1, 1] via ``2 * h - 1``.

        Returns:
            Tensor with one row per sample (the channel axis is squeezed out).
        """
        x = torch.cat([alice_plain, alice_key], dim=1)
        h = torch.sigmoid(self.linear0(x))
        # Conv1d expects (batch, channels, length): add a single channel.
        h = h.reshape(h.shape[0], 1, h.shape[1])
        h = torch.sigmoid(self.conv0(h))
        h = torch.sigmoid(self.conv1(h))
        h = torch.sigmoid(self.conv2(h))
        h = torch.sigmoid(self.conv3(h))
        h = h.squeeze(1)
        if zero_one:
            return h
        return 2 * h - 1

    def get_keys(self, alice_key_idx):
        """Look up the key vectors for the given integer index tensor."""
        return self.keys(alice_key_idx)
# Use the first GPU when there is one, otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
bob_model = Model().to(device)
abe_optim = torch.optim.AdamW([{"params": bob_model.parameters()}], lr=5e-4)

# Move the prepared numpy arrays onto the training device as float32 tensors.
msg_tensor = torch.from_numpy(msg).type(torch.float32).to(device)
enc_msg_tensor = torch.from_numpy(enc_msg).type(torch.float32).to(device)
enc_key_tensor = torch.from_numpy(enc_key).type(torch.float32).to(device)

batch_size = 4096  # not referenced by the visible code
best_loss = 1e5  # lowest loss seen at a checkpoint so far
### train (stop with Ctrl+C once the loss is relatively small (< 1.0)) ###
n = 101  # train on the first n groups of 16 blocks (rows 0 .. 16 * n - 1)
key_loss_every = 1  # add the key-chain loss every this many steps

# Loop-invariant tensors, built once instead of on every step.
train_cipher = enc_msg_tensor[: 16 * n].detach()
train_target = (msg_tensor[: 16 * n].detach() + 1.0) / 2  # {-1, 1} -> {0, 1}
train_key_idx = torch.repeat_interleave(torch.arange(n), 16).to(device)
chain_cipher = enc_key_tensor[: 16 * n : 16]  # one encrypted-key row per group
chain_idx = torch.arange(n).to(device)
next_key_idx = torch.arange(1, len(chain_cipher)).to(device)
first_key_idx = torch.arange(1).to(device)
plus_one = torch.ones((1, 64)).to(device)
bce_sum = nn.BCELoss(reduction="sum")

try:
    print("=" * 80)
    print(n)
    for step in itertools.count():
        bob_model.train()

        # Decryption loss: the model, fed a ciphertext block and the learned
        # key of its group, should output the known plaintext bits.
        alice_key = bob_model.get_keys(train_key_idx)
        bob_plain = bob_model(train_cipher, alice_key, zero_one=True)
        dec_loss = bce_sum(bob_plain, train_target)

        if step % key_loss_every == 0:
            # Key-chain loss: key i + 1 is pulled towards the sign pattern of
            # the model output for (encrypted key of group i, key i).  The
            # target is a constant, so it is computed without autograd.
            with torch.no_grad():
                key_dec = bob_model(
                    chain_cipher, bob_model.get_keys(chain_idx), zero_one=False
                )
            key_dec = torch.where(
                key_dec > 0.0,
                torch.ones_like(key_dec),
                -torch.ones_like(key_dec),
            )
            key_loss = torch.sum(
                torch.abs(key_dec[:-1] - bob_model.get_keys(next_key_idx))
            )
            # Key 0 has no predecessor; just pull each of its components
            # towards whichever of -1 / +1 is closer.
            first_key = bob_model.get_keys(first_key_idx)
            key_loss = key_loss + torch.sum(
                torch.min(
                    torch.abs(first_key - plus_one),
                    torch.abs(first_key + plus_one),
                )
            )
        else:
            key_loss = dec_loss.new_zeros(())

        loss = dec_loss + key_loss
        abe_optim.zero_grad()
        loss.backward()
        abe_optim.step()

        if step % 100 == 0:
            loss_value = loss.item()
            # Checkpoint whenever the loss improves; the decrypt stage loads
            # this file.
            if loss_value < best_loss:
                torch.save(bob_model, "./model.pt")
                best_loss = loss_value
            print(step, loss_value, dec_loss.item(), key_loss.item())

        if step % 1000 == 0:
            usable = len(enc_msg_tensor) // 16 * 16  # whole groups only
            # Preview only: no gradients needed.
            with torch.no_grad():
                out = bob_model(
                    enc_msg_tensor[:usable],
                    bob_model.get_keys(
                        torch.repeat_interleave(
                            torch.arange(usable // 16), 16
                        ).to(device)
                    ),
                    zero_one=False,
                )
            # check whether html page could be reconstructed
            # (decode on the host to avoid per-element device reads)
            preview = out[: 16 * 32].cpu().numpy()
            out_msg_first = b"".join(decode_message(row) for row in preview)
            print(out_msg_first[:512])
except KeyboardInterrupt:
    print("end")
### decrypt ###
# SECURITY: torch.load unpickles arbitrary Python objects.  Only load a
# model.pt that this script produced itself.
# The checkpoint is a whole pickled module, so recent PyTorch releases (where
# weights_only defaults to True) need weights_only=False; releases that
# predate the argument reject it with TypeError, hence the fallback.
try:
    bob_model = torch.load(
        "./model.pt", map_location=device, weights_only=False
    )
except TypeError:
    bob_model = torch.load("./model.pt", map_location=device)

n = 100  # start from the learned key of group n
m = 10  # number of following groups to decrypt
# Inference only: without no_grad the chained key predictions would keep an
# ever-growing autograd graph alive.
with torch.no_grad():
    key_pred = bob_model.get_keys(torch.arange(n, n + 1).to(device))
    for group in range(n, n + m):
        # Derive the next group's key from this group's encrypted-key row.
        # NOTE(review): training binarises this output to -1 / +1 before using
        # it as a target, whereas the raw output is fed forward here.
        # Confirm that this is intended.
        key_pred = bob_model(
            enc_key_tensor[16 * group : 16 * group + 1],
            key_pred,
            zero_one=False,
        )
        # Decrypt the 16 blocks of the next group with the predicted key.
        out = (
            bob_model(
                enc_msg_tensor[16 * (group + 1) : 16 * (group + 2)],
                torch.repeat_interleave(key_pred, 16, dim=0),
                zero_one=False,
            )
            .detach()
            .cpu()
            .numpy()
        )
        print(group, b"".join([decode_message(o) for o in out]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment