import asyncore
import asynchat
import os
import socket
import mimetools
import mimetypes
import urlparse
import sys
import traceback
import cgi
import log
try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

MAX_REQUEST_LINE_SIZE = 4096

here = os.path.abspath(os.path.dirname(__file__))

# some internal handlers
class ResponseHandler(object):
    """ base class for all response handlers

        a handler is an iterator; the strings returned from successive
        next() calls, glued together, make up the complete response
        (status line and headers included)

        next() can do one of three things:

          * return a string: the next piece of the response
          * return None: the request continues without data being written
            to the channel (can be used to let a thread take over control,
            see test_feeder() in test/test_httpserver.py for an example)
          * raise StopIteration(): the handler is done

        this base implementation produces no output at all
    """
    def __init__(self, config, request, channel):
        self.config, self.request, self.channel = config, request, channel

    def __iter__(self):
        # the handler is its own iterator
        return self

    def next(self):
        # nothing to send, so we are done right away
        raise StopIteration

def file_handler_factory(fpath, content_type=None):
    """ create a handler factory that serves the file at fpath

        when content_type is not given it is guessed from the file name
        using the mimetypes module, with 'application/x-unknown' as the
        fallback

        returns a callable with the (config, request, channel) signature
        that instantiates a FileHandler for the file
    """
    if content_type is None:
        guessed = mimetypes.guess_type(fpath)[0]
        content_type = guessed or 'application/x-unknown'

    def handler_wrapper(config, request, channel):
        return FileHandler(fpath, content_type, config, request, channel)

    return handler_wrapper

class FileHandler(ResponseHandler):
    """ response handler that serves a single file from disk

        the first next() call returns the status line and headers, every
        following call returns at most bufsize bytes of file data, until
        fsize bytes (the size of the file when the handler was created)
        have been read
    """
    bufsize = 1024

    def __init__(self, fpath, content_type, config, request, handler):
        # note that the 'handler' argument is passed to the base class as
        # the channel
        super(FileHandler, self).__init__(config, request, handler)
        # make sure the attribute exists, so __del__ doesn't fail when
        # open() raises
        self.fp = None
        # open in binary mode: Content-Length is the size in bytes, so the
        # data must not be subject to newline translation
        self.fp = open(fpath, 'rb')
        self.content_type = content_type
        self.fsize = os.path.getsize(fpath)
        self.header_sent = False

    def next(self):
        """ return the headers on the first call, file data after that

            raises StopIteration() when the whole file is sent
        """
        if not self.header_sent:
            self.header_sent = True
            return self.headers()
        return self.next_chunk()

    def headers(self):
        """ return the status line and headers, including the empty line
            that terminates the header block
        """
        return '\r\n'.join((
            '%s 200 Ok' % (self.request.httpver,),
            'Content-Type: %s' % (self.content_type,),
            'Content-Length: %s' % (self.fsize,),
            '',
            '',
        ))

    def next_chunk(self):
        """ return the next block of file data (at most bufsize bytes)

            never reads past fsize; raises StopIteration() when fsize bytes
            are read or when the file turns out to have no more data
        """
        remaining = self.fsize - self.fp.tell()
        if remaining <= 0:
            raise StopIteration()
        data = self.fp.read(min(self.bufsize, remaining))
        if not data:
            # the file shrunk after we determined its size; stop here
            # rather than returning an empty string
            raise StopIteration()
        return data

    def __del__(self):
        fp = getattr(self, 'fp', None)
        if fp is not None:
            fp.close()

class DirectoryHandler(FileHandler):
    """ response handler that serves the files below a directory

        the part of the request path that follows webroot is looked up
        relative to dirpath; if that results in an existing file inside
        dirpath the file is served, in all other cases a 404 is returned
    """
    def __init__(self, dirpath, webroot, config, request, channel):
        # skip FileHandler.__init__() on purpose (it opens one fixed file),
        # only the ResponseHandler initialization is needed
        super(FileHandler, self).__init__(config, request, channel)
        # make sure the attribute exists, so __del__ doesn't fail when
        # process_request() raises
        self.fp = None
        if webroot.endswith('/'):
            webroot = webroot[:-1]
        # note that self.headers is a string here, it hides the headers()
        # method of FileHandler
        self.headers, self.fp = self.process_request(request, dirpath, webroot)
        self.headers_sent = False

    def process_request(self, request, dirpath, webroot):
        """ find the file to serve for the request

            returns a tuple (headers, fp), where headers is a string with
            the status line and headers, and fp is the opened file, or None
            if there is nothing to serve (404)
        """
        path = request.path.split('?', 1)[0]
        if not path.startswith(webroot):
            return self.headers_404(), None
        relpath = path[len(webroot):]
        # webroot must match whole path segments ('/static' should not
        # match '/staticfoo')
        if relpath and not relpath.startswith('/'):
            return self.headers_404(), None
        # the path comes from the client, so normalize it and refuse
        # anything that ends up outside of dirpath (e.g. '..' segments or
        # an absolute path after a double slash)
        root = os.path.abspath(dirpath)
        fspath = os.path.abspath(os.path.join(root, relpath[1:]))
        prefix = root.rstrip(os.sep) + os.sep
        if not fspath.startswith(prefix) or not os.path.isfile(fspath):
            return self.headers_404(), None
        # binary mode: Content-Length is the size in bytes
        return self.headers_200(fspath), open(fspath, 'rb')

    def next(self):
        """ return the headers on the first call, file data (if there is a
            file to serve) after that
        """
        if not self.headers_sent:
            self.headers_sent = True
            return self.headers
        if self.fp is None:
            raise StopIteration()
        return self.next_chunk()

    def __del__(self):
        fp = getattr(self, 'fp', None)
        if fp is not None:
            fp.close()

    def headers_404(self):
        """ return the status line of a 404 response, followed by the empty
            line that terminates the header block
        """
        return '\r\n'.join((
            '%s 404 Not Found' % (self.request.httpver,),
            '',
            '',
        ))

    def headers_200(self, path):
        """ return status line and headers for serving the file at path

            the content type is guessed from the file name; as a side
            effect self.fsize is set to the size of the file (next_chunk()
            depends on that)
        """
        ct = mimetypes.guess_type(path)[0] or 'application/x-unknown'
        cl = os.path.getsize(path)
        self.fsize = cl
        return '\r\n'.join((
            '%s 200 Ok' % (self.request.httpver,),
            'Content-Type: %s' % (ct,),
            'Content-Length: %s' % (cl,),
            '',
            '',
        ))

class NotFoundHandler(ResponseHandler):
    """ handler that responds with a plain text 404

        'data' holds the lines of the response as templates, they are
        interpolated with the attributes of the request plus 'bodylen'
        (length of the body, EOL included); the last item is the body

        subclasses can provide a different response by overriding 'data'
    """
    data = (
        '%(httpver)s 404 Not Found',
        'Content-Type: text/plain',
        'Content-Length: %(bodylen)s',
        '',
        'Path %(path)s not found',
    )

    def __init__(self, config, request, channel):
        super(NotFoundHandler, self).__init__(config, request, channel)
        self.i = 0

        # values for interpolation: all the request's attributes...
        values = dict(request.__dict__)
        # ...and the size of the body, which gets an EOL like all lines
        body = self.data[-1] % values
        values['bodylen'] = len(body) + 2
        self.iter = (
            (template % values) + '\r\n' for template in self.data)

    def next(self):
        # one line per call, StopIteration when the lines run out
        return self.iter.next()

class MethodNotSupportedHandler(NotFoundHandler):
    """ handler that responds with a plain text 405

        all behaviour is inherited from NotFoundHandler, only the response
        lines differ (the body mentions the request method instead of the
        path)
    """
    data = (
        # status line
        '%(httpver)s 405 Method Not Supported',
        # headers
        'Content-Type: text/plain',
        'Content-Length: %(bodylen)s',
        # end of headers
        '',
        # body
        'Method %(method)s not supported',
    )

class ErrorHandler(ResponseHandler):
    """ handler that responds with a plain text 500

        the body consists of a line describing the exception, followed by
        the formatted traceback

        exc, e and tb are the exception type, exception value and traceback
        (as returned by sys.exc_info())
    """
    def __init__(self, config, request, channel, exc, e, tb):
        super(ErrorHandler, self).__init__(config, request, channel)
        tblines = ['%s - %s' % (exc, e)] + traceback.format_tb(tb)
        # every line is sent with '\r\n' appended (see below), those 2
        # characters have to be counted too, or Content-Length is smaller
        # than the body we actually send
        clen = sum([len(l) + 2 for l in tblines])
        self.data = (line + '\r\n' for line in [
            '%s 500 Server Error' % (request.httpver,),
            'Content-Type: text/plain',
            'Content-Length: %s' % (clen,),
            '',
        ] + tblines)

    def next(self):
        # one line per call, StopIteration when the lines run out
        return self.data.next()

# a wrapper around request information
class Request(object):
    """ holds information about the request

        all attributes are None on instantiation, they are filled by
        parse_request_line() (method, path, httpver), parse_headers()
        (headers) and by whoever reads the request body (body)
    """
    def __init__(self):
        self.method = None # HTTP method e.g. 'GET'
        self.path = None # full path (including query etc.)
        self.httpver = None # string e.g. 'HTTP/1.1'
        self.headers = None # mimetools.Message
        self.body = None # iterator

        self._form = None # cache for the 'form' property

    @property
    def form(self):
        """ the query string and body, parsed into a cgi.FieldStorage

            the FieldStorage is created on first access and cached
        """
        form = self._form
        if form is not None:
            return form
        # item 4 of the urlparse() result is the query string
        query = urlparse.urlparse(self.path)[4]
        env = {
            'QUERY_STRING': query,
            'REQUEST_METHOD': self.method,
        }
        # only pass on the headers that are actually there
        cl = self.headers.get('content-length')
        if cl:
            env['CONTENT_LENGTH'] = cl
        ct = self.headers.get('content-type')
        if ct:
            env['CONTENT_TYPE'] = ct
        form = cgi.FieldStorage(fp=StringIO(self.body or ''), environ=env)
        self._form = form
        return form

    def parse_request_line(self, line):
        """ parse the request line

            return None if the request line was invalid (either unrecognized
            or too long), or a list [method, path, httpver] when it was
            valid; in the latter case the values are also stored on self
        """
        if len(line) > MAX_REQUEST_LINE_SIZE:
            log.log(log.WARNING, 'request line too long (%r)' % (line,))
            return
        reqinfo = line.strip().split(None, 2)
        # XXX no HTTP/0.9 support...
        if len(reqinfo) != 3 or not reqinfo[2].startswith('HTTP/'):
            log.log(log.WARNING, 'unrecognized request line (%r)' % (line,))
            return
        self.method, self.path, self.httpver = reqinfo
        return reqinfo

    def parse_headers(self, data):
        """ parse request headers

            this parses the full set of headers, stores the result as
            self.headers and returns it (an instance of mimetools.Message)
        """
        headers = mimetools.Message(StringIO(data.strip()))
        self.headers = headers
        return headers

    def keepalive(self):
        """ return True if the connection should be kept open for a new request

            this is the case if the client explicitly asks for it
            ('Connection: keep-alive'), or if the request is HTTP/1.1 and
            the client doesn't ask to close the connection
            ('Connection: close')

            returns False if the headers are not parsed (yet)
        """
        if self.headers is None:
            # nothing known about the connection, don't keep it around
            return False
        connection = self.headers.get('connection', '').lower()
        if connection == 'keep-alive':
            return True
        return self.httpver == 'HTTP/1.1' and connection != 'close'

# this dispatches response handling to the handlers, and basically adapts
# the producer API to an iterator one
class RequestProducer(object):
    """ producer that gets its data from a response handler

        more() is the producer side: it returns whatever the handler's
        next() returns; when the handler raises StopIteration the channel's
        cleanup() is called, any other exception makes an ErrorHandler take
        the place of the handler
    """
    def __init__(self, config, channel, request):
        self.config, self.channel, self.request = config, channel, request
        self.handler = None
        self._form = None

    def init(self, handlerclass):
        """ instantiate the handler

            handlerclass is called with (config, request, channel); if that
            raises, an ErrorHandler is installed instead
        """
        try:
            handler = handlerclass(self.config, self.request, self.channel)
        except:
            self.handle_exception()
        else:
            self.handler = handler

    def more(self):
        """ return the next piece of response data

            returns an empty string when there is no handler, when the
            handler is done and when the handler raised an exception
        """
        handler = self.handler
        if not handler:
            return ''
        try:
            return handler.next()
        except StopIteration:
            # the handler is done, have the channel wrap things up
            self.channel.cleanup()
        except:
            # NOTE(review): an empty string is returned after this, while
            # the ErrorHandler still has data to send - confirm asynchat
            # doesn't drop the producer in that case
            self.handle_exception()
        return ''

    def handle_exception(self):
        """ replace the handler with an ErrorHandler and log the exception

            must be called while the exception is being handled, as the
            exception is retrieved using sys.exc_info()
        """
        exc, e, tb = sys.exc_info()
        # drop anything that is still waiting to be sent
        self.channel.discard_buffers()
        self.handler = ErrorHandler(
            self.config, self.request, self.channel, exc, e, tb)
        log.log(log.ERROR, '%s - %s' % (exc, e))
        for tbline in traceback.format_tb(tb):
            log.log(log.DEBUG, tbline.rstrip())
        # don't keep a reference to the traceback around
        del tb

# this wraps an incoming connection and handles receiving the request
class HTTPChannel(asynchat.async_chat):
    """ a single connection

        reads the request in 3 stages (request line, headers, and the body
        if the headers contain a Content-Length), then hands the request to
        the handler registered on the server for the path and method
    """
    def __init__(self, config, sock, addr, server):
        if sock:
            # pass None as sock to not have the channel registered to
            # asyncore.socket_map (mostly useful for unit tests)
            asynchat.async_chat.__init__(self, sock)
        self.config = config
        self.addr = addr
        self.server = server
        self.initvars()

    def initvars(self):
        """ (re)set the state so a new request can be read
        """
        self.set_terminator('\n')
        self.buffer = []
        self.request = Request()
        self.done_reading = False

    def collect_incoming_data(self, data):
        self.buffer.append(data)

    def found_terminator(self):
        """ process the data that is collected so far

            what is done depends on how far we are: if there's no request
            line yet the data is parsed as such, if there are no headers
            yet data is gathered until the headers are complete, else the
            data is the request body
        """
        if self.done_reading:
            return
        req = self.request
        if req.method is None:
            # stage 1: the request line
            line = ''.join(self.buffer)
            self.buffer = []
            reqinfo = req.parse_request_line(line)
            if reqinfo is None:
                return
        elif req.headers is None:
            # stage 2: the headers; the terminator is not part of the
            # collected data, so put it back
            self.buffer.append('\n')
            if not self.headers_end():
                # we didn't have a double nl sequence yet, so keep the data
                # coming...
                return
            data = ''.join(self.buffer)
            self.buffer = []
            headers = req.parse_headers(data)
            try:
                cl = int(headers.get('content-length', 0))
            except ValueError:
                # a bogus header sent by the client should not result in
                # an exception that ends up in the server loop, treat the
                # request as one without a body instead
                log.log(log.WARNING, 'invalid content-length header (%r)' % (
                            headers.get('content-length'),))
                cl = 0
            if cl > 0:
                # read N bytes
                self.set_terminator(cl)
            else:
                self.done_reading = True
                self.handle_request()
        else:
            # stage 3: the body
            self.done_reading = True
            req.body = ''.join(self.buffer)
            self.buffer = []
            self.handle_request()

    def handle_request(self):
        """ find the handler for the request and start producing output

            handlers are looked up in server.handlers by path (without
            query string) and method, NotFoundHandler is used for an
            unknown path and MethodNotSupportedHandler for an unknown
            method
        """
        req = self.request
        path = req.path.split('?')[0]
        handlers = self.server.handlers.get(path)
        if handlers is None:
            handler = NotFoundHandler
        else:
            handler = handlers.get(req.method)
            if handler is None:
                handler = MethodNotSupportedHandler
        rp = RequestProducer(self.config, self, req)
        rp.init(handler)
        self.push_with_producer(rp)

    def headers_end(self):
        """ returns True if the headers are all read

            this is the case if the buffer ends with an empty line, or
            contains nothing but an empty line (no headers at all)
        """
        return (self.buffer[-2:] == ['\n', '\n'] or
                self.buffer[-3:] == ['\n', '\r', '\n'] or
                self.buffer == ['\n'] or self.buffer == ['\r', '\n'])

    def cleanup(self, allow_keepalive=True):
        """ this should be called when the response handler is done

            depending on allow_keepalive and the request, the channel is
            either prepared for a new request, or closed (after pending
            output is sent) and removed from the server's connections
        """
        if allow_keepalive and self.request.keepalive():
            log.log(log.DEBUG, 'allowing keepalive for %s' % (self.addr[0],))
            self.initvars()
        else:
            log.log(log.DEBUG, 'closing connection with %s' % (self.addr[0],))
            self.close_when_done()
            self.server.connections.remove(self.socket)

    def handle_error(self):
        # just raise, let the HTTPServer deal with exceptions
        raise

    def handle_expt(self):
        # XXX not sure what to do with this... but the default sucks for sure
        pass

# the server
class HTTPServer(asyncore.dispatcher):
    """ simple HTTP server

        async server that supports only GET and POST

        'handlers' is a dict mapping paths to dicts, which in turn map HTTP
        methods to handler factories (callables that are called with
        config, request and channel as arguments)
    """
    def __init__(self, config, handlers):
        asyncore.dispatcher.__init__(self)
        self.config = config
        self.handlers = handlers
        self.connections = []
        # pipe used by poll_next() to wake up the loop
        self._piper, self._pipew = os.pipe()
        discarding_file_dispatcher(self._piper)
        self.initialized = False

    def initialize(self):
        """ call this after instantiation

            binds to the socket and starts accepting connections
        """
        try:
            self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
            self.bind((self.config.host, self.config.port))
            self.listen(5)
            self.initialized = True
        except:
            # clean up asyncore's map neatly if we fail to init; there's
            # nothing to clean up if we didn't even get a socket, and the
            # cleanup itself should not hide the original exception
            if self.socket is not None:
                asyncore.socket_map.pop(self.socket.fileno(), None)
                del self.socket
            raise

    def handle_accept(self):
        """ handle a connect

            dispatch to a new httpchannel
        """
        pair = self.accept()
        if pair is None:
            # accept() can return None if the connection didn't take
            # place after all, just keep listening
            return
        conn, addr = pair
        self.connections.append(conn)
        log.log(log.MESSAGE, 'connection from %s accepted on port %s' % addr)
        HTTPChannel(self.config, conn, addr, self)

    def serve(self, use_poll=True, count=None, timeout=None,
              raise_exceptions=False):
        """ serve the pages

            runs asyncore.loop() with the use_poll, count and timeout
            arguments; exceptions raised from the loop are logged, unless
            raise_exceptions is true, in which case they're re-raised
            (KeyboardInterrupt and SystemExit are always re-raised)
        """
        try:
            asyncore.loop(use_poll=use_poll, count=count,
                          timeout=timeout)
        except (KeyboardInterrupt, SystemExit):
            raise # XXX ?
        except:
            if raise_exceptions:
                raise
            exc, e, tb = sys.exc_info()
            log.log(log.ERROR, '%s - %s' % (exc, e))
            for line in traceback.format_tb(tb):
                log.log(log.DEBUG, line.rstrip())

    def poll_next(self):
        """ call this to make poll() return

            this can be used from other threads to make the poll() call stop
            blocking once, this to allow 'feeding' of data to the http channel
            from other threads (this won't work without 'kicking' poll)
        """
        # write something to the pipe to trigger poll() to return
        os.write(self._pipew, '\0')

    def close(self):
        """ close the pipe and socket

            also closes the connections that are still open
        """
        for fd in (self._piper, self._pipew):
            try:
                os.close(fd)
            except OSError:
                # already closed, which is what we want
                pass
        for conn in self.connections:
            try:
                conn.close()
            except socket.error, e:
                log.log(log.DEBUG, 'error closing connection: %s' % (e,))
        asyncore.dispatcher.close(self)

    def handle_error(self):
        """ handle an exception

            logs the exception and traceback and re-raises
        """
        exc, e, tb = sys.exc_info()
        log.log(log.ERROR, '%s - %s' % (exc, e))
        for line in traceback.format_tb(tb):
            log.log(log.DEBUG, line.rstrip())
        raise exc, e, tb

# wraps the read fd of the poll_next pipe
class discarding_file_dispatcher(asyncore.file_dispatcher):
    """ this is used for the pipe for HTTPServer.poll_next()

        all data that arrives is read and thrown away
    """
    def handle_read(self):
        while 1:
            try:
                data = self.recv(2)
            except OSError:
                break
            if len(data) < 2:
                # less data than we asked for (or none at all), so there's
                # nothing left to read
                break

    def handle_expt(self):
        pass

if __name__ == '__main__':
    # boot a test server
    import config
    def root(config, request, channel):
        return (line + '\r\n' for line in (
            '%s 200 Ok' % (request.httpver,),
            'Content-Type: text/plain',
            'Content-Length: 34',
            '',
            'async httpserver for httpmp v0.1',
        ))
    s = HTTPServer(config, {'/': {'GET': root}})
    s.initialize()
    print 'starting server at port', config.port
    s.serve()
    print 'closed down'


