import hashlib
import base64
import socket
import struct
import ssl
import errno
import codecs
from collections import deque
from select import select
import sys
VER = sys.version_info[0]
if VER >= 3:
from http.server import BaseHTTPRequestHandler # pylint: disable=import-error
from io import StringIO, BytesIO
unicode = str # pylint: disable=redefined-builtin
else:
from BaseHTTPServer import BaseHTTPRequestHandler # pylint: disable=import-error
from StringIO import StringIO # pylint: disable=import-error
__all__ = [
'WebSocket',
'WebSocketServer'
]
_VALID_STATUS_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011, 3000, 3999, 4000, 4999]
HANDSHAKE_STR = (
'HTTP/1.1 101 Switching Protocols\r\n'
'Upgrade: WebSocket\r\n'
'Connection: Upgrade\r\n'
'Sec-WebSocket-Accept: %(acceptstr)s\r\n\r\n'
)
GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
STREAM = 0x0
TEXT = 0x1
BINARY = 0x2
CLOSE = 0x8
PING = 0x9
PONG = 0xA
HEADERB1 = 1
HEADERB2 = 3
LENGTHSHORT = 4
LENGTHLONG = 5
MASK = 6
PAYLOAD = 7
MAXHEADER = 65536
MAXPAYLOAD = 33554432
def _check_unicode(val):
if VER >= 3:
return isinstance(val, str)
return isinstance(val, unicode)
class HTTPRequest(BaseHTTPRequestHandler):
def __init__(self, request_text): # pylint: disable=super-init-not-called
if VER >= 3:
self.rfile = BytesIO(request_text)
else:
self.rfile = StringIO(request_text)
self.raw_requestline = self.rfile.readline()
self.error_code = self.error_message = None
self.parse_request()
class WebSocket(object): # pylint: disable=too-many-instance-attributes
def __init__(self, server, sock, address):
self.server = server
self.client = sock
self.address = address
self.handshaked = False
self.headerbuffer = bytearray()
self.headertoread = 2048
self.fin = 0
self.data = bytearray()
self.opcode = 0
self.hasmask = 0
self.maskarray = None
self.length = 0
self.lengtharray = None
self.index = 0
self.request = None
self.usingssl = False
self.frag_start = False
self.frag_type = BINARY
self.frag_buffer = None
self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict')
self.closed = False
self.sendq = deque()
self.state = HEADERB1
# restrict the size of header and payload for security reasons
self.maxheader = MAXHEADER
self.maxpayload = MAXPAYLOAD
def handle(self):
"""
Called when websocket frame is received.
To access the frame data call self.data.
If the frame is Text then self.data is a unicode object.
If the frame is Binary then self.data is a bytearray object.
"""
pass
def connected(self):
"""
Called when a websocket client connects to the server.
"""
pass
def handle_close(self):
"""
Called when a websocket server gets a Close frame from a client.
"""
pass
def _handle_packet(self): # pylint: disable=too-many-branches, too-many-statements
if self.opcode == CLOSE:
pass
elif self.opcode == STREAM:
pass
elif self.opcode == TEXT:
pass
elif self.opcode == BINARY:
pass
elif self.opcode == PONG or self.opcode == PING:
if len(self.data) > 125:
raise Exception('control frame length can not be > 125')
else:
# unknown or reserved opcode so just close
raise Exception('unknown opcode')
if self.opcode == CLOSE:
status = 1000
reason = u''
length = len(self.data)
if length == 0:
pass
elif length >= 2:
status = struct.unpack_from('!H', self.data[:2])[0]
reason = self.data[2:]
if status not in _VALID_STATUS_CODES:
status = 1002
if reason:
try:
reason = reason.decode('utf8', errors='strict')
except Exception: # pylint: disable=broad-except
status = 1002
else:
status = 1002
self.close(status, reason)
return
elif self.fin == 0:
if self.opcode != STREAM:
if self.opcode == PING or self.opcode == PONG:
raise Exception('control messages can not be fragmented')
self.frag_type = self.opcode
self.frag_start = True
self.frag_decoder.reset()
if self.frag_type == TEXT:
self.frag_buffer = []
utf_str = self.frag_decoder.decode(self.data, final=False)
if utf_str:
self.frag_buffer.append(utf_str)
else:
self.frag_buffer = bytearray()
self.frag_buffer.extend(self.data)
else:
if self.frag_start is False:
raise Exception('fragmentation protocol error')
if self.frag_type == TEXT:
utf_str = self.frag_decoder.decode(self.data, final=False)
if utf_str:
self.frag_buffer.append(utf_str)
else:
self.frag_buffer.extend(self.data)
else:
if self.opcode == STREAM:
if self.frag_start is False:
raise Exception('fragmentation protocol error')
if self.frag_type == TEXT:
utf_str = self.frag_decoder.decode(self.data, final=True)
self.frag_buffer.append(utf_str)
self.data = u''.join(self.frag_buffer)
else:
self.frag_buffer.extend(self.data)
self.data = self.frag_buffer
self.handle()
self.frag_decoder.reset()
self.frag_type = BINARY
self.frag_start = False
self.frag_buffer = None
elif self.opcode == PING:
self._send_message(False, PONG, self.data)
elif self.opcode == PONG:
pass
else:
if self.frag_start is True:
raise Exception('fragmentation protocol error')
if self.opcode == TEXT:
try:
self.data = self.data.decode('utf8', errors='strict')
except Exception:
raise Exception('invalid utf-8 payload')
self.handle()
def _handle_data(self):
# do the HTTP header and handshake
if self.handshaked is False:
data = self.client.recv(self.headertoread)
if not data:
raise Exception('remote socket closed')
else:
# accumulate
self.headerbuffer.extend(data)
if len(self.headerbuffer) >= self.maxheader:
raise Exception('header exceeded allowable size')
# indicates end of HTTP header
if b'\r\n\r\n' in self.headerbuffer:
self.request = HTTPRequest(self.headerbuffer)
# handshake rfc 6455
try:
key = self.request.headers['Sec-WebSocket-Key']
k = key.encode('ascii') + GUID_STR.encode('ascii')
k_s = base64.b64encode(hashlib.sha1(k).digest()).decode('ascii')
hs = HANDSHAKE_STR % {'acceptstr': k_s}
self.sendq.append((BINARY, hs.encode('ascii')))
self.handshaked = True
self.connected()
except Exception as e:
raise Exception('handshake failed: {}'.format(e))
# else do normal data
else:
data = self.client.recv(16384)
if not data:
raise Exception("remote socket closed")
if VER >= 3:
for d in data:
self._parse_message(d)
else:
for d in data:
self._parse_message(ord(d))
def close(self, status=1000, reason=u''):
"""
Send Close frame to the client. The underlying socket is only closed
when the client acknowledges the Close frame.
status is the closing identifier.
reason is the reason for the close.
"""
try:
if self.closed is False:
close_msg = bytearray()
close_msg.extend(struct.pack("!H", status))
if _check_unicode(reason):
close_msg.extend(reason.encode('utf-8'))
else:
close_msg.extend(reason)
self._send_message(False, CLOSE, close_msg)
finally:
self.closed = True
def _send_buffer(self, buff, send_all=False):
size = len(buff)
tosend = size
already_sent = 0
while tosend > 0:
try:
# i should be able to send a bytearray
sent = self.client.send(buff[already_sent:])
if sent == 0:
raise RuntimeError('socket connection broken')
already_sent += sent
tosend -= sent
except socket.error as e:
# if we have full buffers then wait for them to drain and try again
if e.errno in [errno.EAGAIN, errno.EWOULDBLOCK]:
if send_all:
continue
return buff[already_sent:]
raise e
return None
def send_fragment_start(self, data):
"""
Send the start of a data fragment stream to a websocket client.
Subsequent data should be sent using sendFragment().
A fragment stream is completed when sendFragmentEnd() is called.
If data is a unicode object then the frame is sent as Text.
If the data is a bytearray object then the frame is sent as Binary.
"""
opcode = BINARY
if _check_unicode(data):
opcode = TEXT
self._send_message(True, opcode, data)
def send_fragment(self, data):
"""
see sendFragmentStart()
If data is a unicode object then the frame is sent as Text.
If the data is a bytearray object then the frame is sent as Binary.
"""
self._send_message(True, STREAM, data)
def send_fragment_end(self, data):
"""
see sendFragmentEnd()
If data is a unicode object then the frame is sent as Text.
If the data is a bytearray object then the frame is sent as Binary.
"""
self._send_message(False, STREAM, data)
def send_message(self, data):
"""
Send websocket data frame to the client.
If data is a unicode object then the frame is sent as Text.
If the data is a bytearray object then the frame is sent as Binary.
"""
opcode = BINARY
if _check_unicode(data):
opcode = TEXT
self._send_message(False, opcode, data)
def _send_message(self, fin, opcode, data):
payload = bytearray()
b1 = 0
b2 = 0
if fin is False:
b1 |= 0x80
b1 |= opcode
if _check_unicode(data):
data = data.encode('utf-8')
length = len(data)
payload.append(b1)
if length <= 125:
b2 |= length
payload.append(b2)
elif 126 <= length <= 65535:
b2 |= 126
payload.append(b2)
payload.extend(struct.pack("!H", length))
else:
b2 |= 127
payload.append(b2)
payload.extend(struct.pack("!Q", length))
if length > 0:
payload.extend(data)
self.sendq.append((opcode, payload))
def _parse_message(self, byte): # pylint: disable=too-many-branches, too-many-statements
# read in the header
if self.state == HEADERB1:
self.fin = byte & 0x80
self.opcode = byte & 0x0F
self.state = HEADERB2
self.index = 0
self.length = 0
self.lengtharray = bytearray()
self.data = bytearray()
rsv = byte & 0x70
if rsv != 0:
raise Exception('RSV bit must be 0')
elif self.state == HEADERB2:
mask = byte & 0x80
length = byte & 0x7F
if self.opcode == PING and length > 125:
raise Exception('ping packet is too large')
self.hasmask = mask == 128
if length <= 125:
self.length = length
# if we have a mask we must read it
if self.hasmask is True:
self.maskarray = bytearray()
self.state = MASK
else:
# if there is no mask and no payload we are done
if self.length <= 0:
try:
self._handle_packet()
finally:
self.state = HEADERB1
self.data = bytearray()
# we have no mask and some payload
else:
# self.index = 0
self.data = bytearray()
self.state = PAYLOAD
elif length == 126:
self.lengtharray = bytearray()
self.state = LENGTHSHORT
elif length == 127:
self.lengtharray = bytearray()
self.state = LENGTHLONG
elif self.state == LENGTHSHORT:
self.lengtharray.append(byte)
if len(self.lengtharray) > 2:
raise Exception('short length exceeded allowable size')
if len(self.lengtharray) == 2:
self.length = struct.unpack_from('!H', self.lengtharray)[0]
if self.hasmask is True:
self.maskarray = bytearray()
self.state = MASK
else:
# if there is no mask and no payload we are done
if self.length <= 0:
try:
self._handle_packet()
finally:
self.state = HEADERB1
self.data = bytearray()
# we have no mask and some payload
else:
# self.index = 0
self.data = bytearray()
self.state = PAYLOAD
elif self.state == LENGTHLONG:
self.lengtharray.append(byte)
if len(self.lengtharray) > 8:
raise Exception('long length exceeded allowable size')
if len(self.lengtharray) == 8:
self.length = struct.unpack_from('!Q', self.lengtharray)[0]
if self.hasmask is True:
self.maskarray = bytearray()
self.state = MASK
else:
# if there is no mask and no payload we are done
if self.length <= 0:
try:
self._handle_packet()
finally:
self.state = HEADERB1
self.data = bytearray()
# we have no mask and some payload
else:
# self.index = 0
self.data = bytearray()
self.state = PAYLOAD
# MASK STATE
elif self.state == MASK:
self.maskarray.append(byte)
if len(self.maskarray) > 4:
raise Exception('mask exceeded allowable size')
if len(self.maskarray) == 4:
# if there is no mask and no payload we are done
if self.length <= 0:
try:
self._handle_packet()
finally:
self.state = HEADERB1
self.data = bytearray()
# we have no mask and some payload
else:
# self.index = 0
self.data = bytearray()
self.state = PAYLOAD
# PAYLOAD STATE
elif self.state == PAYLOAD:
if self.hasmask is True:
self.data.append(byte ^ self.maskarray[self.index % 4])
else:
self.data.append(byte)
# if length exceeds allowable size then we except and remove the connection
if len(self.data) >= self.maxpayload:
raise Exception('payload exceeded allowable size')
# check if we have processed length bytes; if so we are done
if (self.index + 1) == self.length:
try:
self._handle_packet()
finally:
# self.index = 0
self.state = HEADERB1
self.data = bytearray()
else:
self.index += 1
class WebSocketServer(object):
request_queue_size = 5
# pylint: disable=too-many-arguments
def __init__(self, host, port, websocketclass, certfile=None, keyfile=None,
ssl_version=ssl.PROTOCOL_TLSv1, select_interval=0.1):
self.websocketclass = websocketclass
self.serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.serversocket.bind((host, port))
self.serversocket.listen(self.request_queue_size)
self.selectInterval = select_interval
self.connections = {}
self.listeners = [self.serversocket]
self._using_ssl = bool(certfile and keyfile)
if self._using_ssl:
self.context = ssl.SSLContext(ssl_version)
self.context.load_cert_chain(certfile, keyfile)
def _decorate_socket(self, sock): # pylint: disable=no-self-use
if self._using_ssl:
return self.context.wrap_socket(sock, server_side=True)
return sock
def _construct_websocket(self, sock, address):
ws = self.websocketclass(self, sock, address)
if self._using_ssl:
ws.usingssl = True
return ws
def close(self):
self.serversocket.close()
for desc, conn in self.connections.items(): # pylint: disable=unused-variable
conn.close()
self._handle_close(conn)
def _handle_close(self, client): # pylint: disable=no-self-use
client.client.close()
# only call handle_close when we have a successful websocket connection
if client.handshaked:
try:
client.handle_close()
except Exception: # pylint: disable=broad-except
pass
def handle_request(self): # pylint: disable=too-many-branches, too-many-statements, too-many-locals
writers = []
for fileno in self.listeners:
if fileno == self.serversocket:
continue
client = self.connections[fileno]
if client.sendq:
writers.append(fileno)
if self.selectInterval:
r_list, w_list, x_list = select(self.listeners, writers, self.listeners, self.selectInterval)
else:
r_list, w_list, x_list = select(self.listeners, writers, self.listeners)
for ready in w_list:
client = self.connections[ready]
try:
while client.sendq:
opcode, payload = client.sendq.popleft()
remaining = client._send_buffer(payload) # pylint: disable=protected-access
if remaining is not None:
client.sendq.appendleft((opcode, remaining))
break
else:
if opcode == CLOSE:
raise Exception('received client close')
except Exception: # pylint: disable=broad-except
self._handle_close(client)
del self.connections[ready]
self.listeners.remove(ready)
for ready in r_list:
if ready == self.serversocket:
sock = None
try:
sock, address = self.serversocket.accept()
newsock = self._decorate_socket(sock)
newsock.setblocking(0) # pylint: disable=no-member
fileno = newsock.fileno() # pylint: disable=no-member
self.connections[fileno] = self._construct_websocket(newsock, address)
self.listeners.append(fileno)
except Exception: # pylint: disable=broad-except
if sock is not None:
sock.close()
else:
if ready not in self.connections:
continue
client = self.connections[ready]
try:
client._handle_data() # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
self._handle_close(client)
del self.connections[ready]
self.listeners.remove(ready)
for failed in x_list:
if failed == self.serversocket:
self.close()
raise Exception('server socket failed')
else:
if failed not in self.connections:
continue
client = self.connections[failed]
self._handle_close(client)
del self.connections[failed]
self.listeners.remove(failed)
def serve_forever(self):
while True:
self.handle_request()