|
Programming Guru
Join Date: Aug 2005
Location: England
Posts: 1,499
Rep Power: 5 
|
Strongly encrypted stream protocol
I've created a secure stream protocol and library in Python that uses RSA/AES encryption. It's hopefully pretty simple to use, though I'm unsure about the way I've currently gone about implementing it. On the off chance that this is useful to someone, or that someone has a suggestion on how to improve it, I've posted the code below. This protocol needs the Python Cryptography Toolkit (PyCrypto).
A simple echo server-client using this library looks like this:
# --- client.py ---
import socket
import protocol

# Connect over the loopback interface. An empty host string means
# INADDR_ANY, which is meant for bind(); connecting to it is not portable.
host = 'localhost'
sock = socket.socket()
try:
    sock.connect((host, 6000))
    client = protocol.EncryptedClient(sock.recv, sock.sendall, sock.close)
    client.write("Hello World")
    print(client.read(1024))
finally:
    # client.close() would close the same socket; closing it here also
    # covers a failed connect or handshake.
    sock.close()

# --- server.py ---
import socket
import protocol

# Empty host string: listen on all interfaces.
host = ''
sock = socket.socket()
try:
    sock.bind((host, 6000))
    sock.listen(1)
    conn, address = sock.accept()
    try:
        server = protocol.EncryptedServer(conn.recv, conn.sendall, conn.close)
        data = server.read(1024)
        server.write(data)
    finally:
        # Equivalent to server.close(), but also runs if the handshake fails.
        conn.close()
finally:
    sock.close()
The protocol itself is pretty much as basic as you can get:
[client] PublicKey: <a public key as a string>
[server] CipherKey: <a cipher key encrypted with the public key>
<encrypted stream follows>
The public key is RSA, and takes the form: <hex encoded n>.<hex encoded e>
The cipher key is AES, and takes the form: <hex encoded RSA encrypted byte string>
The encrypted stream consists of 16-byte chunks of data encrypted with the AES cipher key. When a chunk is decrypted, its first byte specifies the length of the string contained in the following 15 bytes; any unused bytes are padded with Xs.
The protocol library itself is below:
protocol.py import binascii
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Util.randpool import RandomPool
def hexify(o):
    """
    Encode a non-negative integer or a byte string as lowercase hexadecimal.

    Integers are written without a '0x' prefix or 'L' suffix; strings are
    hex-encoded byte by byte.

    Raises ValueError for negative integers and TypeError for any other
    type of argument.
    """
    if isinstance(o, int) or isinstance(o, long):
        if o < 0:
            # hex(-1)[2:] would give the meaningless 'x1'.
            raise ValueError("cannot hexify a negative number: %r" % (o,))
        # '%x' never adds the Python 2 'L' suffix that hex(long) does.
        return '%x' % o
    if isinstance(o, str):
        return binascii.hexlify(o)
    raise TypeError("cannot hexify object of type %s" % o.__class__.__name__)
def unhexify(o, type):
    """
    Decode a hexadecimal string, as produced by hexify.

    *type* selects the result: str returns the decoded byte string, while
    int or long returns the number. Raises TypeError for any other *type*.
    """
    # NOTE: the parameter name shadows the builtin 'type'; it is kept so
    # that callers passing it by keyword keep working.
    if type is str:
        return binascii.unhexlify(o)
    if type is int or type is long:
        # int() promotes to long automatically for large values on Python 2.
        return int(o, 16)
    raise TypeError("cannot unhexify to %r" % (type,))
def chunks(s, size):
    """
    Yield consecutive slices of *s*, each *size* items long.

    The final slice is shorter when len(s) is not a multiple of *size*.
    An empty *s* yields nothing.
    """
    starts = range(0, len(s), size)
    for begin in starts:
        end = begin + size
        yield s[begin:end]
class BaseKey:
    """Common base of the key wrappers; holds the wrapped key object."""

    def __init__(self, key):
        # Subclasses use the wrapped key directly through this attribute.
        self._key = key
class RSAKey(BaseKey):
    """
    Wrapper around Crypto.PublicKey.RSA.

    Public keys are exchanged as "<hex n>.<hex e>" strings (see export and
    load).

    Encryption uses RSAES-OAEP (Crypto.Cipher.PKCS1_OAEP) rather than raw
    "textbook" RSA. Raw RSA is deterministic and malleable. Its decrypt also
    returns the plaintext as a bare integer's bytes, which drops leading zero
    bytes: a 32-byte AES key starting with NUL would come back as 31 bytes.
    Both peers must use this version, because OAEP ciphertexts are not
    interchangeable with raw-RSA ones.
    """

    @classmethod
    def generate(cls, size=1024):
        """
        Generate a new private key of *size* bits.

        Randomness comes from os.urandom, the operating system's CSPRNG.
        """
        import os
        # NOTE(review): 1024-bit RSA is below current recommendations
        # (2048+); the default is kept for compatibility.
        return cls(RSA.generate(size, os.urandom))

    def export(self):
        """Export the public half of the key as a "<hex n>.<hex e>" string."""
        public_key = self._key.publickey()
        return hexify(public_key.n) + "." + hexify(public_key.e)

    @classmethod
    def load(cls, key):
        """
        Import a public key from a "<hex n>.<hex e>" string.

        Raises ValueError unless the string has exactly two hexadecimal
        fields.
        """
        parts = key.split(".")
        if len(parts) != 2:
            raise ValueError("malformed RSA public key: expected '<n>.<e>'")
        n, e = [unhexify(part, long) for part in parts]
        return cls(RSA.construct((n, e)))

    def _oaep(self):
        "Return an OAEP cipher object bound to the wrapped key."
        from Crypto.Cipher import PKCS1_OAEP
        return PKCS1_OAEP.new(self._key)

    def encrypt(self, s):
        """
        Encrypt a short byte string with RSAES-OAEP.

        With the default SHA-1 hash, the message may be at most the key size
        in bytes minus 42 (86 bytes for a 1024-bit key).
        """
        return self._oaep().encrypt(s)

    def decrypt(self, s):
        """
        Decrypt a string produced by encrypt.

        Imported (public-only) keys cannot decrypt; the cipher raises.
        """
        return self._oaep().decrypt(s)
class AESKey(BaseKey):
    """
    Wrapper around Crypto.Cipher.AES.

    AES is a symmetric cipher, so its key must never be transmitted in
    plaintext. The export and load functions therefore take a key from an
    asymmetric encryption algorithm that has 'encrypt' and 'decrypt' methods.

    Stream format: the plaintext is cut into pieces of at most 15 bytes.
    Each piece becomes one 16-byte block (a length byte, the data, and 'X'
    padding), and each block is encrypted independently.

    NOTE(review): blocks are encrypted in ECB mode with no MAC. Identical
    pieces therefore produce identical ciphertext, and blocks can be
    reordered, dropped or replayed without detection. Fixing this needs a
    protocol change (IV/nonce plus authentication), so it is not done here.
    """

    @classmethod
    def generate(cls, size=32):
        """Generate a new random cipher key of *size* bytes (16, 24 or 32)."""
        import os
        # os.urandom is the OS CSPRNG; the deprecated RandomPool is not.
        return cls(os.urandom(size))

    def export(self, public_key):
        """
        Export the key as a hexadecimal string.

        The key is first encrypted with *public_key*.encrypt, so it never
        appears in plaintext.
        """
        return hexify(public_key.encrypt(self._key))

    @classmethod
    def load(cls, cipher_key, private_key):
        """
        Import a key from an encrypted hexadecimal string produced by export.

        *private_key* must provide a decrypt method.
        """
        return cls(private_key.decrypt(unhexify(cipher_key, str)))

    def _cipher(self):
        """
        Return an AES cipher object for the key.

        ECB is requested explicitly. It is the mode that AES.new(key) selected
        by default in PyCrypto, and newer libraries require the mode argument.
        """
        return AES.new(self._key, AES.MODE_ECB)

    def encrypt(self, s):
        """Encrypt a string into a sequence of 16-byte blocks."""
        # ECB is stateless, so one cipher object can serve every block.
        cipher = self._cipher()
        return ''.join(cipher.encrypt(self._serialize(piece))
                       for piece in chunks(s, 15))

    def decrypt(self, s):
        """
        Decrypt a string produced by encrypt.

        The string length must be a multiple of 16.
        """
        cipher = self._cipher()
        return ''.join(self._unserialize(cipher.decrypt(block))
                       for block in chunks(s, 16))

    @staticmethod
    def _serialize(s, l=15):
        """
        Frame *s* as a length byte, the data, and 'X' padding.

        The result is l + 1 bytes long. Raises ValueError if *s* is longer
        than *l* or than 255 bytes, the most one length byte can record.
        """
        if len(s) > l or len(s) > 255:
            raise ValueError("piece of %d bytes does not fit in %d"
                             % (len(s), l))
        return chr(len(s)) + s + 'X' * (l - len(s))

    @staticmethod
    def _unserialize(s):
        """
        Return the data bytes of a block framed by _serialize.

        Raises ValueError if the length byte claims more data than the block
        holds. This is only a sanity check against a wrong key or corrupted
        ciphertext, not authentication.
        """
        length = ord(s[0])
        if length > len(s) - 1:
            raise ValueError("corrupt block: length byte %d exceeds payload"
                             % length)
        return s[1:1 + length]
class ProtocolError(Exception):
    """Signals that a protocol object received input it cannot accept."""
class BaseProtocol:
    """
    Base class for protocol objects.

    A protocol object is built from three transport callables (read(size),
    write(data) and close()) and exposes methods of the same names. This
    lets users work with any protocol object in the same way. The setup()
    hook runs once at the end of the constructor; it does nothing here, and
    subclasses override it.
    """

    def __init__(self, read, write, close):
        "Store the transport callables, then run the setup hook."
        self._read, self._write, self._close = read, write, close
        self.setup()

    def setup(self):
        "Hook for subclasses, called from __init__. Does nothing by default."

    def read(self, size):
        "Delegate to the read callable with *size* and return its result."
        result = self._read(size)
        return result

    def write(self, data):
        "Delegate to the write callable and return its result."
        result = self._write(data)
        return result

    def close(self):
        "Delegate to the close callable and return its result."
        result = self._close()
        return result
class LineMixin:
    """
    Mixin that adds readline and writeline functions to protocols.

    The host class must provide read(size) and write(data).
    """

    # Bytes requested from read() per call while looking for a newline.
    buffer_size = 15
    # Data read past the end of the last returned line.
    buffer = ""

    def readline(self):
        """
        Read a line and return it without its line terminator.

        At end of stream, returns whatever data remains even if it has no
        newline (or "" when nothing is left), instead of failing.
        """
        while '\n' not in self.buffer:
            s = self.read(self.buffer_size)
            if not s:
                break
            self.buffer += s
        # partition also copes with a final line that has no '\n'.
        line, _, self.buffer = self.buffer.partition('\n')
        return line.rstrip('\r')

    def writeline(self, line):
        "Write a line terminated by CRLF."
        self.write(line + "\r\n")
class AttributeProtocol(BaseProtocol, LineMixin):
    """
    Protocol that sends and receives data via attributes.

    Each line sent should contain one attribute. An attribute is a name and
    a value separated by the first colon on the line, so the value may
    itself contain colons.
    e.g.
    Password: secret
    Access: OK
    Get: file.txt
    """

    def readattr(self, name=None):
        """
        Read an attribute.

        Returns a (name, value) tuple when *name* is None; otherwise returns
        just the value. Raises ProtocolError if the line has no colon, or if
        *name* is given and a different attribute is encountered.
        """
        line = self.readline()
        attr, sep, value = line.partition(":")
        if not sep:
            raise ProtocolError("Malformed attribute line: %r" % line)
        attr, value = attr.strip(), value.strip()
        if name is None:
            return attr, value
        if attr == name:
            return value
        raise ProtocolError("Unexpected attribute: %s" % attr)

    def writeattr(self, name, value):
        "Write an attribute as a 'name: value' line."
        self.writeline(name + ": " + value)
class EncryptedBase(BaseProtocol):
    """
    Encrypted protocol.

    Subclasses must set a cipher_key attribute during initialisation,
    normally by overriding setup(). setup() runs the key exchange over
    self.handshake, a plaintext AttributeProtocol on the same transport,
    before any encrypted data is exchanged.

    NOTE(review): nothing here authenticates the peer's public key, so an
    active man-in-the-middle can substitute their own key.
    """

    # The default AESKey cipher works on 16-byte blocks (its decrypt needs a
    # multiple of 16). Override this for a cipher with another block size.
    block_size = 16

    def __init__(self, read, write, close,
                 PublicKey=RSAKey, CipherKey=AESKey):
        self.PublicKey = PublicKey
        self.CipherKey = CipherKey
        self.handshake = AttributeProtocol(read, write, close)
        # Ciphertext received but not yet decrypted (an incomplete block).
        # Set before BaseProtocol.__init__ so read() works even if setup()
        # uses it.
        self._pending = ""
        BaseProtocol.__init__(self, read, write, close)

    def read(self, size):
        """
        Read from the transport and return the decrypted data.

        *size* is passed to the underlying read callable. A transport such as
        socket.recv may return data that does not end on a block boundary.
        Ciphertext is therefore buffered until at least one whole block is
        available, only whole blocks are decrypted, and any trailing partial
        block is kept for the next call. Returns "" at end of stream when no
        whole block remains.
        """
        if self.handshake.buffer:
            # The line-based handshake reads the transport in small chunks
            # and may already have consumed the start of the encrypted
            # stream; those bytes come before anything read since.
            self._pending = self.handshake.buffer + self._pending
            self.handshake.buffer = ""
        while len(self._pending) < self.block_size:
            data = BaseProtocol.read(self, size)
            if not data:
                break
            self._pending += data
        usable = len(self._pending) - len(self._pending) % self.block_size
        ciphertext = self._pending[:usable]
        self._pending = self._pending[usable:]
        return self.cipher_key.decrypt(ciphertext)

    def write(self, s):
        "Encrypt *s* with the cipher key and write it to the transport."
        return BaseProtocol.write(self, self.cipher_key.encrypt(s))
class EncryptedServer(EncryptedBase):
    """
    Server side of the encrypted protocol.

    The server receives the client's public key and replies with a freshly
    generated cipher key encrypted under it.
    """

    def setup(self):
        """
        Run the server half of the handshake.

        Reads the PublicKey attribute, generates the session cipher key, and
        sends it back (encrypted) as the CipherKey attribute.
        """
        exported_public = self.handshake.readattr("PublicKey")
        self.foreign_key = self.PublicKey.load(exported_public)
        self.cipher_key = self.CipherKey.generate()
        wrapped_key = self.cipher_key.export(self.foreign_key)
        self.handshake.writeattr("CipherKey", wrapped_key)
class EncryptedClient(EncryptedBase):
    """
    Client side of the encrypted protocol.

    The client sends its public key to the server, then uses its own key
    object to decrypt the cipher key the server returns.
    """

    def __init__(self, read, write, close, public_key=None,
                 PublicKey=RSAKey, CipherKey=AESKey):
        """
        Create the client and run the handshake.

        Despite its name, *public_key* is the client's own key object. It
        must be able to export() its public half and to decrypt, because
        setup() passes it to CipherKey.load. When it is None, a new one is
        created with PublicKey.generate().
        """
        if public_key is None:
            public_key = PublicKey.generate()
        self.public_key = public_key
        EncryptedBase.__init__(self, read, write, close, PublicKey, CipherKey)

    def setup(self):
        "Set up the secure connection with the server."
        self.handshake.writeattr("PublicKey", self.public_key.export())
        self.cipher_key = self.CipherKey.load(
            self.handshake.readattr("CipherKey"), self.public_key)
|