From 732e7edba494ef7beef24e81ed35de6fa0dc9ede Mon Sep 17 00:00:00 2001 From: Markus Birth Date: Sun, 3 Aug 2025 01:38:31 +0100 Subject: [PATCH] Initial commit Signed-off-by: Markus Birth --- .gitignore | 4 + Makefile | 7 + README.md | 15 + main.py | 32 + microdot/__init__.py | 2 + microdot/helpers.py | 8 + microdot/microdot.py | 1532 ++++++++++++++++++++++++++++++++++++++ microdot/multipart.py | 291 ++++++++ microdot/session.py | 155 ++++ microdot/sse.py | 126 ++++ microdot/utemplate.py | 70 ++ requests/__init__.mpy | Bin 0 -> 2434 bytes templates/index.html | 36 + tflcountdown/__init__.py | 149 ++++ urequests.mpy | Bin 0 -> 101 bytes utemplate/compiled.py | 14 + utemplate/recompile.py | 21 + utemplate/source.py | 188 +++++ 18 files changed, 2650 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 README.md create mode 100644 main.py create mode 100644 microdot/__init__.py create mode 100644 microdot/helpers.py create mode 100644 microdot/microdot.py create mode 100644 microdot/multipart.py create mode 100644 microdot/session.py create mode 100644 microdot/sse.py create mode 100644 microdot/utemplate.py create mode 100644 requests/__init__.mpy create mode 100644 templates/index.html create mode 100644 tflcountdown/__init__.py create mode 100644 urequests.mpy create mode 100644 utemplate/compiled.py create mode 100644 utemplate/recompile.py create mode 100644 utemplate/source.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..41fcfd2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.DS_Store + +# utemplate files +/templates/*.py diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..20eb916 --- /dev/null +++ b/Makefile @@ -0,0 +1,7 @@ +.PHONY: clean +clean: + rm -f templates/*.py + +.PHONY: run +run: + micropython main.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..e533930 --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +TfL Countdown +============= + +This is a personalised TfL Countdown website that's supposed to +show when the next bus departs from your nearest bus stop(s). + +It's written in [MicroPython](https://micropython.org). + + make run + +This runs the `main.py` using `micropython` (needs to be installed). + + make clean + +This cleans all the compiled templates from the `templates` directory. diff --git a/main.py b/main.py new file mode 100644 index 0000000..c2ece1e --- /dev/null +++ b/main.py @@ -0,0 +1,32 @@ +from microdot import Microdot +from microdot.utemplate import Template +import tflcountdown as tfl + +# utemplate doc: https://github.com/pfalcon/utemplate + +API_KEY = "NOT_YET_REQUIRED" +# Stop-IDs from https://tfl.gov.uk/bus-stops.csv +STOP_IDS = { + "Acton Vale (N)": "1597", + "Acton Vale (S)": "1598", + "Abinger Road (N)": "11333", + "Abinger Road (S)": "11334" + + #H1227,58839,490018676N,Hail & Ride Larden Road,521409,179656,350,6408,1 + #H1228,N/A,490018676S,Hail & Ride Larden Road,521420,179655,170,6408,1 +} +LINE_IDS = { + "272": "272" +} + +app = Microdot() +tflc = tfl.TflCountdown(API_KEY) + +@app.route("/") +async def index(request): + response = tflc.get_countdown(["1597", "1598", "11333", "11334", "R0199"]) + data = tflc.parse_countdown(response.text) + print(repr(data)) + return Template("index.html").render(data), {"Content-Type": "text/html"} + +app.run(port=5001, debug=True) diff --git a/microdot/__init__.py b/microdot/__init__.py new file mode 100644 index 0000000..2637085 --- /dev/null +++ b/microdot/__init__.py @@ -0,0 +1,2 @@ +from microdot.microdot import Microdot, Request, Response, abort, redirect, \ + send_file, URLPattern, AsyncBytesIO, iscoroutine # noqa: F401 diff --git a/microdot/helpers.py b/microdot/helpers.py new file mode 100644 index 0000000..664e58c --- /dev/null +++ b/microdot/helpers.py @@ -0,0 +1,8 @@ +try: + from functools import wraps +except ImportError: # pragma: no cover + # MicroPython does not currently implement functools.wraps + def wraps(wrapped): + def _(wrapper): + return wrapper + return _ diff --git a/microdot/microdot.py b/microdot/microdot.py new file mode 100644 index 0000000..eb1d8d0 --- /dev/null +++ b/microdot/microdot.py @@ -0,0 +1,1532 @@ +""" +microdot +-------- + +The ``microdot`` module defines a few classes that help implement HTTP-based +servers for MicroPython and standard Python. +""" +import asyncio +import io +import re +import time + +try: + import orjson as json +except ImportError: + import json + +try: + from inspect import iscoroutinefunction, iscoroutine + from functools import partial + + async def invoke_handler(handler, *args, **kwargs): + """Invoke a handler and return the result. + + This method runs sync handlers in a thread pool executor. + """ + if iscoroutinefunction(handler): + ret = await handler(*args, **kwargs) + else: + ret = await asyncio.get_running_loop().run_in_executor( + None, partial(handler, *args, **kwargs)) + return ret +except ImportError: # pragma: no cover + def iscoroutine(coro): + return hasattr(coro, 'send') and hasattr(coro, 'throw') + + async def invoke_handler(handler, *args, **kwargs): + """Invoke a handler and return the result. + + This method runs sync handlers in the asyncio thread, which can + potentially cause blocking and performance issues. + """ + ret = handler(*args, **kwargs) + if iscoroutine(ret): + ret = await ret + return ret + +try: + from sys import print_exception +except ImportError: # pragma: no cover + import traceback + + def print_exception(exc): + traceback.print_exc() + +MUTED_SOCKET_ERRORS = [ + 32, # Broken pipe + 54, # Connection reset by peer + 104, # Connection reset by peer + 128, # Operation on closed socket +] + + +def urldecode(s): + if isinstance(s, str): + s = s.encode() + s = s.replace(b'+', b' ') + parts = s.split(b'%') + if len(parts) == 1: + return s.decode() + result = [parts[0]] + for item in parts[1:]: + if item == b'': + result.append(b'%') + else: + code = item[:2] + result.append(bytes([int(code, 16)])) + result.append(item[2:]) + return b''.join(result).decode() + + +def urlencode(s): + return s.replace('+', '%2B').replace(' ', '+').replace( + '%', '%25').replace('?', '%3F').replace('#', '%23').replace( + '&', '%26').replace('=', '%3D') + + +class NoCaseDict(dict): + """A subclass of dictionary that holds case-insensitive keys. + + :param initial_dict: an initial dictionary of key/value pairs to + initialize this object with. + + Example:: + + >>> d = NoCaseDict() + >>> d['Content-Type'] = 'text/html' + >>> print(d['Content-Type']) + text/html + >>> print(d['content-type']) + text/html + >>> print(d['CONTENT-TYPE']) + text/html + >>> del d['cOnTeNt-TyPe'] + >>> print(d) + {} + """ + def __init__(self, initial_dict=None): + super().__init__(initial_dict or {}) + self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k} + + def __setitem__(self, key, value): + kl = key.lower() + key = self.keymap.get(kl, key) + if kl != key: + self.keymap[kl] = key + super().__setitem__(key, value) + + def __getitem__(self, key): + kl = key.lower() + return super().__getitem__(self.keymap.get(kl, kl)) + + def __delitem__(self, key): + kl = key.lower() + super().__delitem__(self.keymap.get(kl, kl)) + + def __contains__(self, key): + kl = key.lower() + return self.keymap.get(kl, kl) in self.keys() + + def get(self, key, default=None): + kl = key.lower() + return super().get(self.keymap.get(kl, kl), default) + + def update(self, other_dict): + for key, value in other_dict.items(): + self[key] = value + + +def mro(cls): # pragma: no cover + """Return the method resolution order of a class. + + This is a helper function that returns the method resolution order of a + class. It is used by Microdot to find the best error handler to invoke for + the raised exception. + + In CPython, this function returns the ``__mro__`` attribute of the class. + In MicroPython, this function implements a recursive depth-first scanning + of the class hierarchy. + """ + if hasattr(cls, 'mro'): + return cls.__mro__ + + def _mro(cls): + m = [cls] + for base in cls.__bases__: + m += _mro(base) + return m + + mro_list = _mro(cls) + + # If a class appears multiple times (due to multiple inheritance) remove + # all but the last occurence. This matches the method resolution order + # of MicroPython, but not CPython. + mro_pruned = [] + for i in range(len(mro_list)): + base = mro_list.pop(0) + if base not in mro_list: + mro_pruned.append(base) + return mro_pruned + + +class MultiDict(dict): + """A subclass of dictionary that can hold multiple values for the same + key. It is used to hold key/value pairs decoded from query strings and + form submissions. + + :param initial_dict: an initial dictionary of key/value pairs to + initialize this object with. + + Example:: + + >>> d = MultiDict() + >>> d['sort'] = 'name' + >>> d['sort'] = 'email' + >>> print(d['sort']) + 'name' + >>> print(d.getlist('sort')) + ['name', 'email'] + """ + def __init__(self, initial_dict=None): + super().__init__() + if initial_dict: + for key, value in initial_dict.items(): + self[key] = value + + def __setitem__(self, key, value): + if key not in self: + super().__setitem__(key, []) + super().__getitem__(key).append(value) + + def __getitem__(self, key): + return super().__getitem__(key)[0] + + def get(self, key, default=None, type=None): + """Return the value for a given key. + + :param key: The key to retrieve. + :param default: A default value to use if the key does not exist. + :param type: A type conversion callable to apply to the value. + + If the multidict contains more than one value for the requested key, + this method returns the first value only. + + Example:: + + >>> d = MultiDict() + >>> d['age'] = '42' + >>> d.get('age') + '42' + >>> d.get('age', type=int) + 42 + >>> d.get('name', default='noname') + 'noname' + """ + if key not in self: + return default + value = self[key] + if type is not None: + value = type(value) + return value + + def getlist(self, key, type=None): + """Return all the values for a given key. + + :param key: The key to retrieve. + :param type: A type conversion callable to apply to the values. + + If the requested key does not exist in the dictionary, this method + returns an empty list. + + Example:: + + >>> d = MultiDict() + >>> d.getlist('items') + [] + >>> d['items'] = '3' + >>> d.getlist('items') + ['3'] + >>> d['items'] = '56' + >>> d.getlist('items') + ['3', '56'] + >>> d.getlist('items', type=int) + [3, 56] + """ + if key not in self: + return [] + values = super().__getitem__(key) + if type is not None: + values = [type(value) for value in values] + return values + + +class AsyncBytesIO: + """An async wrapper for BytesIO.""" + def __init__(self, data): + self.stream = io.BytesIO(data) + + async def read(self, n=-1): + return self.stream.read(n) + + async def readline(self): # pragma: no cover + return self.stream.readline() + + async def readexactly(self, n): # pragma: no cover + return self.stream.read(n) + + async def readuntil(self, separator=b'\n'): # pragma: no cover + return self.stream.readuntil(separator=separator) + + async def awrite(self, data): # pragma: no cover + return self.stream.write(data) + + async def aclose(self): # pragma: no cover + pass + + +class Request: + """An HTTP request.""" + #: Specify the maximum payload size that is accepted. Requests with larger + #: payloads will be rejected with a 413 status code. Applications can + #: change this maximum as necessary. + #: + #: Example:: + #: + #: Request.max_content_length = 1 * 1024 * 1024 # 1MB requests allowed + max_content_length = 16 * 1024 + + #: Specify the maximum payload size that can be stored in ``body``. + #: Requests with payloads that are larger than this size and up to + #: ``max_content_length`` bytes will be accepted, but the application will + #: only be able to access the body of the request by reading from + #: ``stream``. Set to 0 if you always access the body as a stream. + #: + #: Example:: + #: + #: Request.max_body_length = 4 * 1024 # up to 4KB bodies read + max_body_length = 16 * 1024 + + #: Specify the maximum length allowed for a line in the request. Requests + #: with longer lines will not be correctly interpreted. Applications can + #: change this maximum as necessary. + #: + #: Example:: + #: + #: Request.max_readline = 16 * 1024 # 16KB lines allowed + max_readline = 2 * 1024 + + class G: + pass + + def __init__(self, app, client_addr, method, url, http_version, headers, + body=None, stream=None, sock=None, url_prefix='', + subapp=None): + #: The application instance to which this request belongs. + self.app = app + #: The address of the client, as a tuple (host, port). + self.client_addr = client_addr + #: The HTTP method of the request. + self.method = method + #: The request URL, including the path and query string. + self.url = url + #: The URL prefix, if the endpoint comes from a mounted + #: sub-application, or else ''. + self.url_prefix = url_prefix + #: The sub-application instance, or `None` if this isn't a mounted + #: endpoint. + self.subapp = subapp + #: The path portion of the URL. + self.path = url + #: The query string portion of the URL. + self.query_string = None + #: The parsed query string, as a + #: :class:`MultiDict ` object. + self.args = {} + #: A dictionary with the headers included in the request. + self.headers = headers + #: A dictionary with the cookies included in the request. + self.cookies = {} + #: The parsed ``Content-Length`` header. + self.content_length = 0 + #: The parsed ``Content-Type`` header. + self.content_type = None + #: A general purpose container for applications to store data during + #: the life of the request. + self.g = Request.G() + + self.http_version = http_version + if '?' in self.path: + self.path, self.query_string = self.path.split('?', 1) + self.args = self._parse_urlencoded(self.query_string) + + if 'Content-Length' in self.headers: + self.content_length = int(self.headers['Content-Length']) + if 'Content-Type' in self.headers: + self.content_type = self.headers['Content-Type'] + if 'Cookie' in self.headers: + for cookie in self.headers['Cookie'].split(';'): + name, value = cookie.strip().split('=', 1) + self.cookies[name] = value + + self._body = body + self.body_used = False + self._stream = stream + self.sock = sock + self._json = None + self._form = None + self._files = None + self.after_request_handlers = [] + + @staticmethod + async def create(app, client_reader, client_writer, client_addr): + """Create a request object. + + :param app: The Microdot application instance. + :param client_reader: An input stream from where the request data can + be read. + :param client_writer: An output stream where the response data can be + written. + :param client_addr: The address of the client, as a tuple. + + This method is a coroutine. It returns a newly created ``Request`` + object. + """ + # request line + line = (await Request._safe_readline(client_reader)).strip().decode() + if not line: # pragma: no cover + return None + method, url, http_version = line.split() + http_version = http_version.split('/', 1)[1] + + # headers + headers = NoCaseDict() + content_length = 0 + while True: + line = (await Request._safe_readline( + client_reader)).strip().decode() + if line == '': + break + header, value = line.split(':', 1) + value = value.strip() + headers[header] = value + if header.lower() == 'content-length': + content_length = int(value) + + # body + body = b'' + if content_length and content_length <= Request.max_body_length: + body = await client_reader.readexactly(content_length) + stream = None + else: + body = b'' + stream = client_reader + + return Request(app, client_addr, method, url, http_version, headers, + body=body, stream=stream, + sock=(client_reader, client_writer)) + + def _parse_urlencoded(self, urlencoded): + data = MultiDict() + if len(urlencoded) > 0: # pragma: no branch + if isinstance(urlencoded, str): + for kv in [pair.split('=', 1) + for pair in urlencoded.split('&') if pair]: + data[urldecode(kv[0])] = urldecode(kv[1]) \ + if len(kv) > 1 else '' + elif isinstance(urlencoded, bytes): # pragma: no branch + for kv in [pair.split(b'=', 1) + for pair in urlencoded.split(b'&') if pair]: + data[urldecode(kv[0])] = urldecode(kv[1]) \ + if len(kv) > 1 else b'' + return data + + @property + def body(self): + """The body of the request, as bytes.""" + return self._body + + @property + def stream(self): + """The body of the request, as a bytes stream.""" + if self._stream is None: + self._stream = AsyncBytesIO(self._body) + return self._stream + + @property + def json(self): + """The parsed JSON body, or ``None`` if the request does not have a + JSON body.""" + if self._json is None: + if self.content_type is None: + return None + mime_type = self.content_type.split(';')[0] + if mime_type != 'application/json': + return None + self._json = json.loads(self.body.decode()) + return self._json + + @property + def form(self): + """The parsed form submission body, as a + :class:`MultiDict ` object, or ``None`` if the + request does not have a form submission. + + Forms that are URL encoded are processed by default. For multipart + forms to be processed, the + :func:`with_form_data ` + decorator must be added to the route. + """ + if self._form is None: + if self.content_type is None: + return None + mime_type = self.content_type.split(';')[0] + if mime_type != 'application/x-www-form-urlencoded': + return None + self._form = self._parse_urlencoded(self.body) + return self._form + + @property + def files(self): + """The files uploaded in the request as a dictionary, or ``None`` if + the request does not have any files. + + The :func:`with_form_data ` + decorator must be added to the route that receives file uploads for + this property to be set. + """ + return self._files + + def after_request(self, f): + """Register a request-specific function to run after the request is + handled. Request-specific after request handlers run at the very end, + after the application's own after request handlers. The function must + take two arguments, the request and response objects. The return value + of the function must be the updated response object. + + Example:: + + @app.route('/') + def index(request): + # register a request-specific after request handler + @req.after_request + def func(request, response): + # ... + return response + + return 'Hello, World!' + + Note that the function is not called if the request handler raises an + exception and an error response is returned instead. + """ + self.after_request_handlers.append(f) + return f + + @staticmethod + async def _safe_readline(stream): + line = (await stream.readline()) + if len(line) > Request.max_readline: + raise ValueError('line too long') + return line + + +class Response: + """An HTTP response class. + + :param body: The body of the response. If a dictionary or list is given, + a JSON formatter is used to generate the body. If a file-like + object or an async generator is given, a streaming response is + used. If a string is given, it is encoded from UTF-8. Else, + the body should be a byte sequence. + :param status_code: The numeric HTTP status code of the response. The + default is 200. + :param headers: A dictionary of headers to include in the response. + :param reason: A custom reason phrase to add after the status code. The + default is "OK" for responses with a 200 status code and + "N/A" for any other status codes. + """ + types_map = { + 'css': 'text/css', + 'gif': 'image/gif', + 'html': 'text/html', + 'jpg': 'image/jpeg', + 'js': 'application/javascript', + 'json': 'application/json', + 'png': 'image/png', + 'txt': 'text/plain', + 'svg': 'image/svg+xml', + } + + send_file_buffer_size = 1024 + + #: The content type to use for responses that do not explicitly define a + #: ``Content-Type`` header. + default_content_type = 'text/plain' + + #: The default cache control max age used by :meth:`send_file`. A value + #: of ``None`` means that no ``Cache-Control`` header is added. + default_send_file_max_age = None + + #: Special response used to signal that a response does not need to be + #: written to the client. Used to exit WebSocket connections cleanly. + already_handled = None + + def __init__(self, body='', status_code=200, headers=None, reason=None): + if body is None and status_code == 200: + body = '' + status_code = 204 + self.status_code = status_code + self.headers = NoCaseDict(headers or {}) + self.reason = reason + if isinstance(body, (dict, list)): + body = json.dumps(body) + self.headers['Content-Type'] = 'application/json; charset=UTF-8' + if isinstance(body, str): + self.body = body.encode() + else: + # this applies to bytes, file-like objects or generators + self.body = body + self.is_head = False + + def set_cookie(self, cookie, value, path=None, domain=None, expires=None, + max_age=None, secure=False, http_only=False, + partitioned=False): + """Add a cookie to the response. + + :param cookie: The cookie's name. + :param value: The cookie's value. + :param path: The cookie's path. + :param domain: The cookie's domain. + :param expires: The cookie expiration time, as a ``datetime`` object + or a correctly formatted string. + :param max_age: The cookie's ``Max-Age`` value. + :param secure: The cookie's ``secure`` flag. + :param http_only: The cookie's ``HttpOnly`` flag. + :param partitioned: Whether the cookie is partitioned. + """ + http_cookie = '{cookie}={value}'.format(cookie=cookie, value=value) + if path: + http_cookie += '; Path=' + path + if domain: + http_cookie += '; Domain=' + domain + if expires: + if isinstance(expires, str): + http_cookie += '; Expires=' + expires + else: # pragma: no cover + http_cookie += '; Expires=' + time.strftime( + '%a, %d %b %Y %H:%M:%S GMT', expires.timetuple()) + if max_age is not None: + http_cookie += '; Max-Age=' + str(max_age) + if secure: + http_cookie += '; Secure' + if http_only: + http_cookie += '; HttpOnly' + if partitioned: + http_cookie += '; Partitioned' + if 'Set-Cookie' in self.headers: + self.headers['Set-Cookie'].append(http_cookie) + else: + self.headers['Set-Cookie'] = [http_cookie] + + def delete_cookie(self, cookie, **kwargs): + """Delete a cookie. + + :param cookie: The cookie's name. + :param kwargs: Any cookie opens and flags supported by + ``set_cookie()`` except ``expires`` and ``max_age``. + """ + self.set_cookie(cookie, '', expires='Thu, 01 Jan 1970 00:00:01 GMT', + max_age=0, **kwargs) + + def complete(self): + if isinstance(self.body, bytes) and \ + 'Content-Length' not in self.headers: + self.headers['Content-Length'] = str(len(self.body)) + if 'Content-Type' not in self.headers: + self.headers['Content-Type'] = self.default_content_type + if 'charset=' not in self.headers['Content-Type']: + self.headers['Content-Type'] += '; charset=UTF-8' + + async def write(self, stream): + self.complete() + + try: + # status code + reason = self.reason if self.reason is not None else \ + ('OK' if self.status_code == 200 else 'N/A') + await stream.awrite('HTTP/1.0 {status_code} {reason}\r\n'.format( + status_code=self.status_code, reason=reason).encode()) + + # headers + for header, value in self.headers.items(): + values = value if isinstance(value, list) else [value] + for value in values: + await stream.awrite('{header}: {value}\r\n'.format( + header=header, value=value).encode()) + await stream.awrite(b'\r\n') + + # body + if not self.is_head: + iter = self.body_iter() + async for body in iter: + if isinstance(body, str): # pragma: no cover + body = body.encode() + try: + await stream.awrite(body) + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS or \ + exc.args[0] == 'Connection lost': + if hasattr(iter, 'aclose'): + await iter.aclose() + raise + if hasattr(iter, 'aclose'): # pragma: no branch + await iter.aclose() + + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS or \ + exc.args[0] == 'Connection lost': + pass + else: + raise + + def body_iter(self): + if hasattr(self.body, '__anext__'): + # response body is an async generator + return self.body + + response = self + + class iter: + ITER_UNKNOWN = 0 + ITER_SYNC_GEN = 1 + ITER_FILE_OBJ = 2 + ITER_NO_BODY = -1 + + def __aiter__(self): + if response.body: + self.i = self.ITER_UNKNOWN # need to determine type + else: + self.i = self.ITER_NO_BODY + return self + + async def __anext__(self): + if self.i == self.ITER_NO_BODY: + await self.aclose() + raise StopAsyncIteration + if self.i == self.ITER_UNKNOWN: + if hasattr(response.body, 'read'): + self.i = self.ITER_FILE_OBJ + elif hasattr(response.body, '__next__'): + self.i = self.ITER_SYNC_GEN + return next(response.body) + else: + self.i = self.ITER_NO_BODY + return response.body + elif self.i == self.ITER_SYNC_GEN: + try: + return next(response.body) + except StopIteration: + await self.aclose() + raise StopAsyncIteration + buf = response.body.read(response.send_file_buffer_size) + if iscoroutine(buf): # pragma: no cover + buf = await buf + if len(buf) < response.send_file_buffer_size: + self.i = self.ITER_NO_BODY + return buf + + async def aclose(self): + if hasattr(response.body, 'close'): + result = response.body.close() + if iscoroutine(result): # pragma: no cover + await result + + return iter() + + @classmethod + def redirect(cls, location, status_code=302): + """Return a redirect response. + + :param location: The URL to redirect to. + :param status_code: The 3xx status code to use for the redirect. The + default is 302. + """ + if '\x0d' in location or '\x0a' in location: + raise ValueError('invalid redirect URL') + return cls(status_code=status_code, headers={'Location': location}) + + @classmethod + def send_file(cls, filename, status_code=200, content_type=None, + stream=None, max_age=None, compressed=False, + file_extension=''): + """Send file contents in a response. + + :param filename: The filename of the file. + :param status_code: The 3xx status code to use for the redirect. The + default is 302. + :param content_type: The ``Content-Type`` header to use in the + response. If omitted, it is generated + automatically from the file extension of the + ``filename`` parameter. + :param stream: A file-like object to read the file contents from. If + a stream is given, the ``filename`` parameter is only + used when generating the ``Content-Type`` header. + :param max_age: The ``Cache-Control`` header's ``max-age`` value in + seconds. If omitted, the value of the + :attr:`Response.default_send_file_max_age` attribute is + used. + :param compressed: Whether the file is compressed. If ``True``, the + ``Content-Encoding`` header is set to ``gzip``. A + string with the header value can also be passed. + Note that when using this option the file must have + been compressed beforehand. This option only sets + the header. + :param file_extension: A file extension to append to the ``filename`` + parameter when opening the file, including the + dot. The extension given here is not considered + when generating the ``Content-Type`` header. + + Security note: The filename is assumed to be trusted. Never pass + filenames provided by the user without validating and sanitizing them + first. + """ + if content_type is None: + if compressed and filename.endswith('.gz'): + ext = filename[:-3].split('.')[-1] + else: + ext = filename.split('.')[-1] + if ext in Response.types_map: + content_type = Response.types_map[ext] + else: + content_type = 'application/octet-stream' + headers = {'Content-Type': content_type} + + if max_age is None: + max_age = cls.default_send_file_max_age + if max_age is not None: + headers['Cache-Control'] = 'max-age={}'.format(max_age) + + if compressed: + headers['Content-Encoding'] = compressed \ + if isinstance(compressed, str) else 'gzip' + + f = stream or open(filename + file_extension, 'rb') + return cls(body=f, status_code=status_code, headers=headers) + + +class URLPattern(): + """A class that represents the URL pattern for a route. + + :param url_pattern: The route URL pattern, which can include static and + dynamic path segments. Dynamic segments are enclosed in + ``<`` and ``>``. The type of the segment can be given + as a prefix, separated from the name with a colon. + Supported types are ``string`` (the default), + ``int`` and ``path``. Custom types can be registered + using the :meth:`URLPattern.register_type` method. + """ + + segment_patterns = { + 'string': '/([^/]+)', + 'int': '/(-?\\d+)', + 'path': '/(.+)', + } + segment_parsers = { + 'int': lambda value: int(value), + } + + @classmethod + def register_type(cls, type_name, pattern='[^/]+', parser=None): + """Register a new URL segment type. + + :param type_name: The name of the segment type to register. + :param pattern: The regular expression pattern to use when matching + this segment type. If not given, a default matcher for + a single path segment is used. + :param parser: A callable that will be used to parse and transform the + value of the segment. If omitted, the value is returned + as a string. + """ + cls.segment_patterns[type_name] = '/({})'.format(pattern) + cls.segment_parsers[type_name] = parser + + def __init__(self, url_pattern): + self.url_pattern = url_pattern + self.segments = [] + self.regex = None + + def compile(self): + """Generate a regular expression for the URL pattern. + + This method is automatically invoked the first time the URL pattern is + matched against a path. + """ + pattern = '' + for segment in self.url_pattern.lstrip('/').split('/'): + if segment and segment[0] == '<': + if segment[-1] != '>': + raise ValueError('invalid URL pattern') + segment = segment[1:-1] + if ':' in segment: + type_, name = segment.rsplit(':', 1) + else: + type_ = 'string' + name = segment + parser = None + if type_.startswith('re:'): + pattern += '/({pattern})'.format(pattern=type_[3:]) + else: + if type_ not in self.segment_patterns: + raise ValueError('invalid URL segment type') + pattern += self.segment_patterns[type_] + parser = self.segment_parsers.get(type_) + self.segments.append({'parser': parser, 'name': name, + 'type': type_}) + else: + pattern += '/' + segment + self.segments.append({'parser': None}) + self.regex = re.compile('^' + pattern + '$') + return self.regex + + def match(self, path): + """Match a path against the URL pattern. + + Returns a dictionary with the values of all dynamic path segments if a + matche is found, or ``None`` if the path does not match this pattern. + """ + args = {} + g = (self.regex or self.compile()).match(path) + if not g: + return + i = 1 + for segment in self.segments: + if 'name' not in segment: + continue + arg = g.group(i) + if segment['parser']: + arg = self.segment_parsers[segment['type']](arg) + if arg is None: + return + args[segment['name']] = arg + i += 1 + return args + + def __repr__(self): # pragma: no cover + return 'URLPattern: {}'.format(self.url_pattern) + + +class HTTPException(Exception): + def __init__(self, status_code, reason=None): + self.status_code = status_code + self.reason = reason or str(status_code) + ' error' + + def __repr__(self): # pragma: no cover + return 'HTTPException: {}'.format(self.status_code) + + +class Microdot: + """An HTTP application class. + + This class implements an HTTP application instance and is heavily + influenced by the ``Flask`` class of the Flask framework. It is typically + declared near the start of the main application script. + + Example:: + + from microdot import Microdot + + app = Microdot() + """ + + def __init__(self): + self.url_map = [] + self.before_request_handlers = [] + self.after_request_handlers = [] + self.after_error_request_handlers = [] + self.error_handlers = {} + self.shutdown_requested = False + self.options_handler = self.default_options_handler + self.debug = False + self.server = None + + def route(self, url_pattern, methods=None): + """Decorator that is used to register a function as a request handler + for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + :param methods: The list of HTTP methods to be handled by the + decorated function. If omitted, only ``GET`` requests + are handled. + + The URL pattern can be a static path (for example, ``/users`` or + ``/api/invoices/search``) or a path with dynamic components enclosed + in ``<`` and ``>`` (for example, ``/users/`` or + ``/invoices//products``). Dynamic path components can also + include a type prefix, separated from the name with a colon (for + example, ``/users/``). The type can be ``string`` (the + default), ``int``, ``path`` or ``re:[regular-expression]``. + + The first argument of the decorated function must be + the request object. Any path arguments that are specified in the URL + pattern are passed as keyword arguments. The return value of the + function must be a :class:`Response` instance, or the arguments to + be passed to this class. + + Example:: + + @app.route('/') + def index(request): + return 'Hello, world!' + """ + def decorated(f): + self.url_map.append( + ([m.upper() for m in (methods or ['GET'])], + URLPattern(url_pattern), f, '', None)) + return f + return decorated + + def get(self, url_pattern): + """Decorator that is used to register a function as a ``GET`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['GET']``. + + Example:: + + @app.get('/users/') + def get_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['GET']) + + def post(self, url_pattern): + """Decorator that is used to register a function as a ``POST`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the``route`` decorator with + ``methods=['POST']``. + + Example:: + + @app.post('/users') + def create_user(request): + # ... + """ + return self.route(url_pattern, methods=['POST']) + + def put(self, url_pattern): + """Decorator that is used to register a function as a ``PUT`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['PUT']``. + + Example:: + + @app.put('/users/') + def edit_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['PUT']) + + def patch(self, url_pattern): + """Decorator that is used to register a function as a ``PATCH`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['PATCH']``. + + Example:: + + @app.patch('/users/') + def edit_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['PATCH']) + + def delete(self, url_pattern): + """Decorator that is used to register a function as a ``DELETE`` + request handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['DELETE']``. + + Example:: + + @app.delete('/users/') + def delete_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['DELETE']) + + def before_request(self, f): + """Decorator to register a function to run before each request is + handled. The decorated function must take a single argument, the + request object. + + Example:: + + @app.before_request + def func(request): + # ... + """ + self.before_request_handlers.append(f) + return f + + def after_request(self, f): + """Decorator to register a function to run after each request is + handled. The decorated function must take two arguments, the request + and response objects. The return value of the function must be an + updated response object. + + Example:: + + @app.after_request + def func(request, response): + # ... + return response + """ + self.after_request_handlers.append(f) + return f + + def after_error_request(self, f): + """Decorator to register a function to run after an error response is + generated. The decorated function must take two arguments, the request + and response objects. The return value of the function must be an + updated response object. The handler is invoked for error responses + generated by Microdot, as well as those returned by application-defined + error handlers. + + Example:: + + @app.after_error_request + def func(request, response): + # ... + return response + """ + self.after_error_request_handlers.append(f) + return f + + def errorhandler(self, status_code_or_exception_class): + """Decorator to register a function as an error handler. Error handler + functions for numeric HTTP status codes must accept a single argument, + the request object. Error handler functions for Python exceptions + must accept two arguments, the request object and the exception + object. + + :param status_code_or_exception_class: The numeric HTTP status code or + Python exception class to + handle. + + Examples:: + + @app.errorhandler(404) + def not_found(request): + return 'Not found' + + @app.errorhandler(RuntimeError) + def runtime_error(request, exception): + return 'Runtime error' + """ + def decorated(f): + self.error_handlers[status_code_or_exception_class] = f + return f + return decorated + + def mount(self, subapp, url_prefix='', local=False): + """Mount a sub-application, optionally under the given URL prefix. + + :param subapp: The sub-application to mount. + :param url_prefix: The URL prefix to mount the application under. + :param local: When set to ``True``, the before, after and error request + handlers only apply to endpoints defined in the + sub-application. When ``False``, they apply to the entire + application. The default is ``False``. + """ + for methods, pattern, handler, _prefix, _subapp in subapp.url_map: + self.url_map.append( + (methods, URLPattern(url_prefix + pattern.url_pattern), + handler, url_prefix + _prefix, _subapp or subapp)) + if not local: + for handler in subapp.before_request_handlers: + self.before_request_handlers.append(handler) + subapp.before_request_handlers = [] + for handler in subapp.after_request_handlers: + self.after_request_handlers.append(handler) + subapp.after_request_handlers = [] + for handler in subapp.after_error_request_handlers: + self.after_error_request_handlers.append(handler) + subapp.after_error_request_handlers = [] + for status_code, handler in subapp.error_handlers.items(): + self.error_handlers[status_code] = handler + subapp.error_handlers = {} + + @staticmethod + def abort(status_code, reason=None): + """Abort the current request and return an error response with the + given status code. + + :param status_code: The numeric status code of the response. + :param reason: The reason for the response, which is included in the + response body. + + Example:: + + from microdot import abort + + @app.route('/users/') + def get_user(id): + user = get_user_by_id(id) + if user is None: + abort(404) + return user.to_dict() + """ + raise HTTPException(status_code, reason) + + async def start_server(self, host='0.0.0.0', port=5000, debug=False, + ssl=None): + """Start the Microdot web server as a coroutine. This coroutine does + not normally return, as the server enters an endless listening loop. + The :func:`shutdown` function provides a method for terminating the + server gracefully. + + :param host: The hostname or IP address of the network interface that + will be listening for requests. A value of ``'0.0.0.0'`` + (the default) indicates that the server should listen for + requests on all the available interfaces, and a value of + ``127.0.0.1`` indicates that the server should listen + for requests only on the internal networking interface of + the host. + :param port: The port number to listen for requests. The default is + port 5000. + :param debug: If ``True``, the server logs debugging information. The + default is ``False``. + :param ssl: An ``SSLContext`` instance or ``None`` if the server should + not use TLS. The default is ``None``. + + This method is a coroutine. + + Example:: + + import asyncio + from microdot import Microdot + + app = Microdot() + + @app.route('/') + async def index(request): + return 'Hello, world!' + + async def main(): + await app.start_server(debug=True) + + asyncio.run(main()) + """ + self.debug = debug + + async def serve(reader, writer): + if not hasattr(writer, 'awrite'): # pragma: no cover + # CPython provides the awrite and aclose methods in 3.8+ + async def awrite(self, data): + self.write(data) + await self.drain() + + async def aclose(self): + self.close() + await self.wait_closed() + + from types import MethodType + writer.awrite = MethodType(awrite, writer) + writer.aclose = MethodType(aclose, writer) + + await self.handle_request(reader, writer) + + if self.debug: # pragma: no cover + print('Starting async server on {host}:{port}...'.format( + host=host, port=port)) + + try: + self.server = await asyncio.start_server(serve, host, port, + ssl=ssl) + except TypeError: # pragma: no cover + self.server = await asyncio.start_server(serve, host, port) + + while True: + try: + if hasattr(self.server, 'serve_forever'): # pragma: no cover + try: + await self.server.serve_forever() + except asyncio.CancelledError: + pass + await self.server.wait_closed() + break + except AttributeError: # pragma: no cover + # the task hasn't been initialized in the server object yet + # wait a bit and try again + await asyncio.sleep(0.1) + + def run(self, host='0.0.0.0', port=5000, debug=False, ssl=None): + """Start the web server. This function does not normally return, as + the server enters an endless listening loop. The :func:`shutdown` + function provides a method for terminating the server gracefully. + + :param host: The hostname or IP address of the network interface that + will be listening for requests. A value of ``'0.0.0.0'`` + (the default) indicates that the server should listen for + requests on all the available interfaces, and a value of + ``127.0.0.1`` indicates that the server should listen + for requests only on the internal networking interface of + the host. + :param port: The port number to listen for requests. The default is + port 5000. + :param debug: If ``True``, the server logs debugging information. The + default is ``False``. + :param ssl: An ``SSLContext`` instance or ``None`` if the server should + not use TLS. The default is ``None``. + + Example:: + + from microdot import Microdot + + app = Microdot() + + @app.route('/') + async def index(request): + return 'Hello, world!' + + app.run(debug=True) + """ + asyncio.run(self.start_server(host=host, port=port, debug=debug, + ssl=ssl)) # pragma: no cover + + def shutdown(self): + """Request a server shutdown. The server will then exit its request + listening loop and the :func:`run` function will return. This function + can be safely called from a route handler, as it only schedules the + server to terminate as soon as the request completes. + + Example:: + + @app.route('/shutdown') + def shutdown(request): + request.app.shutdown() + return 'The server is shutting down...' + """ + self.server.close() + + def find_route(self, req): + method = req.method.upper() + if method == 'OPTIONS' and self.options_handler: + return self.options_handler(req), '', None + if method == 'HEAD': + method = 'GET' + f = 404 + p = '' + s = None + for route_methods, route_pattern, route_handler, url_prefix, subapp \ + in self.url_map: + req.url_args = route_pattern.match(req.path) + if req.url_args is not None: + p = url_prefix + s = subapp + if method in route_methods: + f = route_handler + break + else: + f = 405 + return f, p, s + + def default_options_handler(self, req): + allow = [] + for route_methods, route_pattern, _, _, _ in self.url_map: + if route_pattern.match(req.path) is not None: + allow.extend(route_methods) + if 'GET' in allow: + allow.append('HEAD') + allow.append('OPTIONS') + return {'Allow': ', '.join(allow)} + + async def handle_request(self, reader, writer): + req = None + try: + req = await Request.create(self, reader, writer, + writer.get_extra_info('peername')) + except Exception as exc: # pragma: no cover + print_exception(exc) + + res = await self.dispatch_request(req) + try: + if res != Response.already_handled: # pragma: no branch + await res.write(writer) + await writer.aclose() + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS: + pass + else: + raise + if self.debug and req: # pragma: no cover + print('{method} {path} {status_code}'.format( + method=req.method, path=req.path, + status_code=res.status_code)) + + def get_request_handlers(self, req, attr, local_first=True): + handlers = getattr(self, attr + '_handlers') + local_handlers = getattr(req.subapp, attr + '_handlers') \ + if req and req.subapp else [] + return local_handlers + handlers if local_first \ + else handlers + local_handlers + + async def error_response(self, req, status_code, reason=None): + if req and req.subapp and status_code in req.subapp.error_handlers: + return await invoke_handler( + req.subapp.error_handlers[status_code], req) + elif status_code in self.error_handlers: + return await invoke_handler(self.error_handlers[status_code], req) + return reason or 'N/A', status_code + + async def dispatch_request(self, req): + after_request_handled = False + if req: + if req.content_length > req.max_content_length: + # the request body is larger than allowed + res = await self.error_response(req, 413, 'Payload too large') + else: + # find the route in the app's URL map + f, req.url_prefix, req.subapp = self.find_route(req) + + try: + res = None + if callable(f): + # invoke the before request handlers + for handler in self.get_request_handlers( + req, 'before_request', False): + res = await invoke_handler(handler, req) + if res: + break + + # invoke the endpoint handler + if res is None: + res = await invoke_handler(f, req, **req.url_args) + + # process the response + if isinstance(res, int): + # an integer response is taken as a status code + # with an empty body + res = '', res + if isinstance(res, tuple): + # handle a tuple response + if isinstance(res[0], int): + # a tuple that starts with an int has an empty + # body + res = ('', res[0], + res[1] if len(res) > 1 else {}) + body = res[0] + if isinstance(res[1], int): + # extract the status code and headers (if + # available) + status_code = res[1] + headers = res[2] if len(res) > 2 else {} + else: + # if the status code is missing, assume 200 + status_code = 200 + headers = res[1] + res = Response(body, status_code, headers) + elif not isinstance(res, Response): + # any other response types are wrapped in a + # Response object + res = Response(res) + + # invoke the after request handlers + for handler in self.get_request_handlers( + req, 'after_request', True): + res = await invoke_handler( + handler, req, res) or res + for handler in req.after_request_handlers: + res = await invoke_handler( + handler, req, res) or res + after_request_handled = True + elif isinstance(f, dict): + # the response from an OPTIONS request is a dict with + # headers + res = Response(headers=f) + else: + # if the route is not found, return a 404 or 405 + # response as appropriate + res = await self.error_response(req, f, 'Not found') + except HTTPException as exc: + # an HTTP exception was raised while handling this request + res = await self.error_response(req, exc.status_code, + exc.reason) + except Exception as exc: + # an unexpected exception was raised while handling this + # request + print_exception(exc) + + # invoke the error handler for the exception class if one + # exists + handler = None + res = None + if req.subapp and exc.__class__ in \ + req.subapp.error_handlers: + handler = req.subapp.error_handlers[exc.__class__] + elif exc.__class__ in self.error_handlers: + handler = self.error_handlers[exc.__class__] + else: + # walk up the exception class hierarchy to try to find + # a handler + for c in mro(exc.__class__)[1:]: + if req.subapp and c in req.subapp.error_handlers: + handler = req.subapp.error_handlers[c] + break + elif c in self.error_handlers: + handler = self.error_handlers[c] + break + if handler: + try: + res = await invoke_handler(handler, req, exc) + except Exception as exc2: # pragma: no cover + print_exception(exc2) + if res is None: + # if there is still no response, issue a 500 error + res = await self.error_response( + req, 500, 'Internal server error') + else: + # if the request could not be parsed, issue a 400 error + res = await self.error_response(req, 400, 'Bad request') + if isinstance(res, tuple): + res = Response(*res) + elif not isinstance(res, Response): + res = Response(res) + if not after_request_handled: + # if the request did not finish due to an error, invoke the after + # error request handler + for handler in self.get_request_handlers( + req, 'after_error_request', True): + res = await invoke_handler( + handler, req, res) or res + res.is_head = (req and req.method == 'HEAD') + return res + + +Response.already_handled = Response() + +abort = Microdot.abort +redirect = Response.redirect +send_file = Response.send_file diff --git a/microdot/multipart.py b/microdot/multipart.py new file mode 100644 index 0000000..62acc70 --- /dev/null +++ b/microdot/multipart.py @@ -0,0 +1,291 @@ +import os +from random import choice +from microdot import abort, iscoroutine, AsyncBytesIO +from microdot.helpers import wraps + + +class FormDataIter: + """Asynchronous iterator that parses a ``multipart/form-data`` body and + returns form fields and files as they are parsed. + + :param request: the request object to parse. + + Example usage:: + + from microdot.multipart import FormDataIter + + @app.post('/upload') + async def upload(request): + async for name, value in FormDataIter(request): + print(name, value) + + The iterator returns no values when the request has a content type other + than ``multipart/form-data``. For a file field, the returned value is of + type :class:`FileUpload`, which supports the + :meth:`read() ` and :meth:`save() ` + methods. Values for regular fields are provided as strings. + + The request body is read efficiently in chunks of size + :attr:`buffer_size `. On iterations in which a + file field is encountered, the file must be consumed before moving on to + the next iteration, as the internal stream stored in ``FileUpload`` + instances is invalidated at the end of the iteration. + """ + #: The size of the buffer used to read chunks of the request body. + buffer_size = 256 + + def __init__(self, request): + self.request = request + self.buffer = None + try: + mimetype, boundary = request.content_type.rsplit('; boundary=', 1) + except ValueError: + return # not a multipart request + if mimetype.split(';', 1)[0] == \ + 'multipart/form-data': # pragma: no branch + self.boundary = b'--' + boundary.encode() + self.extra_size = len(boundary) + 4 + self.buffer = b'' + + def __aiter__(self): + return self + + async def __anext__(self): + if self.buffer is None: + raise StopAsyncIteration + + # make sure we have consumed the previous entry + while await self._read_buffer(self.buffer_size) != b'': + pass + + # make sure we are at a boundary + s = self.buffer.split(self.boundary, 1) + if len(s) != 2 or s[0] != b'': + abort(400) # pragma: no cover + self.buffer = s[1] + if self.buffer[:2] == b'--': + # we have reached the end + raise StopAsyncIteration + elif self.buffer[:2] != b'\r\n': + abort(400) # pragma: no cover + self.buffer = self.buffer[2:] + + # parse the headers of this part + name = '' + filename = None + content_type = None + while True: + await self._fill_buffer() + lines = self.buffer.split(b'\r\n', 1) + if len(lines) != 2: + abort(400) # pragma: no cover + line, self.buffer = lines + if line == b'': + # we reached the end of the headers + break + header, value = line.decode().split(':', 1) + header = header.lower() + value = value.strip() + if header == 'content-disposition': + parts = value.split(';') + if len(parts) < 2 or parts[0] != 'form-data': + abort(400) # pragma: no cover + for part in parts[1:]: + part = part.strip() + if part.startswith('name="'): + name = part[6:-1] + elif part.startswith('filename="'): # pragma: no branch + filename = part[10:-1] + elif header == 'content-type': # pragma: no branch + content_type = value + + if filename is None: + # this is a regular form field, so we read the value + value = b'' + while True: + v = await self._read_buffer(self.buffer_size) + value += v + if len(v) < self.buffer_size: # pragma: no branch + break + return name, value.decode() + return name, FileUpload(filename, content_type, self._read_buffer) + + async def _fill_buffer(self): + self.buffer += await self.request.stream.read( + self.buffer_size + self.extra_size - len(self.buffer)) + + async def _read_buffer(self, n=-1): + data = b'' + while n == -1 or len(data) < n: + await self._fill_buffer() + s = self.buffer.split(self.boundary, 1) + data += s[0][:n] if n != -1 else s[0] + self.buffer = s[0][n:] if n != -1 else b'' + if len(s) == 2: # pragma: no branch + # the end of this part is in the buffer + if len(self.buffer) < 2: + # we have read all the way to the end of this part + data = data[:-(2 - len(self.buffer))] # remove last "\r\n" + self.buffer += self.boundary + s[1] + return data + return data + + +class FileUpload: + """Class that represents an uploaded file. + + :param filename: the name of the uploaded file. + :param content_type: the content type of the uploaded file. + :param read: a coroutine that reads from the uploaded file's stream. + + An uploaded file can be read from the stream using the :meth:`read()` + method or saved to a file using the :meth:`save()` method. + + Instances of this class do not normally need to be created directly. + """ + #: The size at which the file is copied to a temporary file. + max_memory_size = 1024 + + def __init__(self, filename, content_type, read): + self.filename = filename + self.content_type = content_type + self._read = read + self._close = None + + async def read(self, n=-1): + """Read up to ``n`` bytes from the uploaded file's stream. + + :param n: the maximum number of bytes to read. If ``n`` is -1 or not + given, the entire file is read. + """ + return await self._read(n) + + async def save(self, path_or_file): + """Save the uploaded file to the given path or file object. + + :param path_or_file: the path to save the file to, or a file object + to which the file is to be written. + + The file is read and written in chunks of size + :attr:`FormDataIter.buffer_size`. + """ + if isinstance(path_or_file, str): + f = open(path_or_file, 'wb') + else: + f = path_or_file + while True: + data = await self.read(FormDataIter.buffer_size) + if not data: + break + f.write(data) + if f != path_or_file: + f.close() + + async def copy(self, max_memory_size=None): + """Copy the uploaded file to a temporary file, to allow the parsing of + the multipart form to continue. + + :param max_memory_size: the maximum size of the file to keep in memory. + If not given, then the class attribute of the + same name is used. + """ + max_memory_size = max_memory_size or FileUpload.max_memory_size + buffer = await self.read(max_memory_size) + if len(buffer) < max_memory_size: + f = AsyncBytesIO(buffer) + self._read = f.read + return self + + # create a temporary file + while True: + tmpname = "".join([ + choice('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + for _ in range(12) + ]) + try: + f = open(tmpname, 'x+b') + except OSError as e: # pragma: no cover + if e.errno == 17: + # EEXIST + continue + elif e.errno == 2: + # ENOENT + # some MicroPython platforms do not support mode "x" + f = open(tmpname, 'w+b') + if f.read(1) != b'': + f.close() + continue + else: + raise + break + f.write(buffer) + await self.save(f) + f.seek(0) + + async def read(n=-1): + return f.read(n) + + async def close(): + f.close() + os.remove(tmpname) + + self._read = read + self._close = close + return self + + async def close(self): + """Close an open file. + + This method must be called to free memory or temporary files created by + the ``copy()`` method. + + Note that when using the ``@with_form_data`` decorator this method is + called automatically when the request ends. + """ + if self._close: + await self._close() + self._close = None + + +def with_form_data(f): + """Decorator that parses a ``multipart/form-data`` body and updates the + request object with the parsed form fields and files. + + Example usage:: + + from microdot.multipart import with_form_data + + @app.post('/upload') + @with_form_data + async def upload(request): + print('form fields:', request.form) + print('files:', request.files) + + Note: this decorator calls the :meth:`FileUpload.copy() + ` method on all uploaded files, so that + the request can be parsed in its entirety. The files are either copied to + memory or a temporary file, depending on their size. The temporary files + are automatically deleted when the request ends. + """ + @wraps(f) + async def wrapper(request, *args, **kwargs): + form = {} + files = {} + async for name, value in FormDataIter(request): + if isinstance(value, FileUpload): + files[name] = await value.copy() + else: + form[name] = value + if form or files: + request._form = form + request._files = files + try: + ret = f(request, *args, **kwargs) + if iscoroutine(ret): + ret = await ret + finally: + if request.files: + for file in request.files.values(): + await file.close() + return ret + return wrapper diff --git a/microdot/session.py b/microdot/session.py new file mode 100644 index 0000000..d8ce085 --- /dev/null +++ b/microdot/session.py @@ -0,0 +1,155 @@ +import jwt +from microdot.microdot import invoke_handler +from microdot.helpers import wraps + + +class SessionDict(dict): + """A session dictionary. + + The session dictionary is a standard Python dictionary that has been + extended with convenience ``save()`` and ``delete()`` methods. + """ + def __init__(self, request, session_dict): + super().__init__(session_dict) + self.request = request + + def save(self): + """Update the session cookie.""" + self.request.app._session.update(self.request, self) + + def delete(self): + """Delete the session cookie.""" + self.request.app._session.delete(self.request) + + +class Session: + """ + :param app: The application instance. + :param key: The secret key, as a string or bytes object. + """ + secret_key = None + + def __init__(self, app=None, secret_key=None, cookie_options=None): + self.secret_key = secret_key + self.cookie_options = cookie_options or {} + if app is not None: + self.initialize(app) + + def initialize(self, app, secret_key=None, cookie_options=None): + if secret_key is not None: + self.secret_key = secret_key + if cookie_options is not None: + self.cookie_options = cookie_options + if 'path' not in self.cookie_options: + self.cookie_options['path'] = '/' + if 'http_only' not in self.cookie_options: + self.cookie_options['http_only'] = True + app._session = self + + def get(self, request): + """Retrieve the user session. + + :param request: The client request. + + The return value is a session dictionary with the data stored in the + user's session, or ``{}`` if the session data is not available or + invalid. + """ + if not self.secret_key: + raise ValueError('The session secret key is not configured') + if hasattr(request.g, '_session'): + return request.g._session + session = request.cookies.get('session') + if session is None: + request.g._session = SessionDict(request, {}) + return request.g._session + request.g._session = SessionDict(request, self.decode(session)) + return request.g._session + + def update(self, request, session): + """Update the user session. + + :param request: The client request. + :param session: A dictionary with the update session data for the user. + + Applications would normally not call this method directly, instead they + would use the :meth:`SessionDict.save` method on the session + dictionary, which calls this method. For example:: + + @app.route('/') + @with_session + def index(request, session): + session['foo'] = 'bar' + session.save() + return 'Hello, World!' + + Calling this method adds a cookie with the updated session to the + request currently being processed. + """ + if not self.secret_key: + raise ValueError('The session secret key is not configured') + + encoded_session = self.encode(session) + + @request.after_request + def _update_session(request, response): + response.set_cookie('session', encoded_session, + **self.cookie_options) + return response + + def delete(self, request): + """Remove the user session. + + :param request: The client request. + + Applications would normally not call this method directly, instead they + would use the :meth:`SessionDict.delete` method on the session + dictionary, which calls this method. For example:: + + @app.route('/') + @with_session + def index(request, session): + session.delete() + return 'Hello, World!' + + Calling this method adds a cookie removal header to the request + currently being processed. + """ + @request.after_request + def _delete_session(request, response): + response.delete_cookie('session', **self.cookie_options) + return response + + def encode(self, payload, secret_key=None): + return jwt.encode(payload, secret_key or self.secret_key, + algorithm='HS256') + + def decode(self, session, secret_key=None): + try: + payload = jwt.decode(session, secret_key or self.secret_key, + algorithms=['HS256']) + except jwt.exceptions.PyJWTError: # pragma: no cover + return {} + return payload + + +def with_session(f): + """Decorator that passes the user session to the route handler. + + The session dictionary is passed to the decorated function as an argument + after the request object. Example:: + + @app.route('/') + @with_session + def index(request, session): + return 'Hello, World!' + + Note that the decorator does not save the session. To update the session, + call the :func:`session.save() ` method. + """ + @wraps(f) + async def wrapper(request, *args, **kwargs): + return await invoke_handler( + f, request, request.app._session.get(request), *args, **kwargs) + + return wrapper diff --git a/microdot/sse.py b/microdot/sse.py new file mode 100644 index 0000000..6376ee0 --- /dev/null +++ b/microdot/sse.py @@ -0,0 +1,126 @@ +import asyncio +from microdot.helpers import wraps + +try: + import orjson as json +except ImportError: + import json + + +class SSE: + """Server-Sent Events object. + + An object of this class is sent to handler functions to manage the SSE + connection. + """ + def __init__(self): + self.event = asyncio.Event() + self.queue = [] + + async def send(self, data, event=None, event_id=None): + """Send an event to the client. + + :param data: the data to send. It can be given as a string, bytes, dict + or list. Dictionaries and lists are serialized to JSON. + Any other types are converted to string before sending. + :param event: an optional event name, to send along with the data. If + given, it must be a string. + :param event_id: an optional event id, to send along with the data. If + given, it must be a string. + """ + if isinstance(data, (dict, list)): + data = json.dumps(data) + if isinstance(data, str): + data = data.encode() + elif not isinstance(data, bytes): + data = str(data).encode() + data = b'data: ' + data + b'\n\n' + if event_id: + data = b'id: ' + event_id.encode() + b'\n' + data + if event: + data = b'event: ' + event.encode() + b'\n' + data + self.queue.append(data) + self.event.set() + + +def sse_response(request, event_function, *args, **kwargs): + """Return a response object that initiates an event stream. + + :param request: the request object. + :param event_function: an asynchronous function that will send events to + the client. The function is invoked with ``request`` + and an ``sse`` object. The function should use + ``sse.send()`` to send events to the client. + :param args: additional positional arguments to be passed to the response. + :param kwargs: additional keyword arguments to be passed to the response. + + This is a low-level function that can be used to implement a custom SSE + endpoint. In general the :func:`microdot.sse.with_sse` decorator should be + used instead. + """ + sse = SSE() + + async def sse_task_wrapper(): + try: + await event_function(request, sse, *args, **kwargs) + except asyncio.CancelledError: # pragma: no cover + pass + except Exception as exc: + # the SSE task raised an exception so we need to pass it to the + # main route so that it is re-raised there + sse.queue.append(exc) + sse.event.set() + + task = asyncio.create_task(sse_task_wrapper()) + + class sse_loop: + def __aiter__(self): + return self + + async def __anext__(self): + event = None + while sse.queue or not task.done(): + try: + event = sse.queue.pop(0) + break + except IndexError: + await sse.event.wait() + sse.event.clear() + if isinstance(event, Exception): + # if the event is an exception we re-raise it here so that it + # can be handled appropriately + raise event + elif event is None: + raise StopAsyncIteration + return event + + async def aclose(self): + task.cancel() + + return sse_loop(), 200, {'Content-Type': 'text/event-stream'} + + +def with_sse(f): + """Decorator to make a route a Server-Sent Events endpoint. + + This decorator is used to define a route that accepts SSE connections. The + route then receives a sse object as a second argument that it can use to + send events to the client:: + + @app.route('/events') + @with_sse + async def events(request, sse): + # send an unnamed event with string data + await sse.send('hello') + + # send an unnamed event with JSON data + await sse.send({'foo': 'bar'}) + + # send a named event + await sse.send('hello', event='greeting') + """ + @wraps(f) + async def sse_handler(request, *args, **kwargs): + return sse_response(request, f, *args, **kwargs) + + return sse_handler diff --git a/microdot/utemplate.py b/microdot/utemplate.py new file mode 100644 index 0000000..16d0398 --- /dev/null +++ b/microdot/utemplate.py @@ -0,0 +1,70 @@ +from utemplate import recompile + +_loader = None + + +class Template: + """A template object. + + :param template: The filename of the template to render, relative to the + configured template directory. + """ + @classmethod + def initialize(cls, template_dir='templates', + loader_class=recompile.Loader): + """Initialize the templating subsystem. + + :param template_dir: the directory where templates are stored. This + argument is optional. The default is to load + templates from a *templates* subdirectory. + :param loader_class: the ``utemplate.Loader`` class to use when loading + templates. This argument is optional. The default + is the ``recompile.Loader`` class, which + automatically recompiles templates when they + change. + """ + global _loader + _loader = loader_class(None, template_dir) + + def __init__(self, template): + if _loader is None: # pragma: no cover + self.initialize() + #: The name of the template + self.name = template + self.template = _loader.load(template) + + def generate(self, *args, **kwargs): + """Return a generator that renders the template in chunks, with the + given arguments.""" + return self.template(*args, **kwargs) + + def render(self, *args, **kwargs): + """Render the template with the given arguments and return it as a + string.""" + return ''.join(self.generate(*args, **kwargs)) + + def generate_async(self, *args, **kwargs): + """Return an asynchronous generator that renders the template in + chunks, using the given arguments.""" + class sync_to_async_iter(): + def __init__(self, iter): + self.iter = iter + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + return sync_to_async_iter(self.generate(*args, **kwargs)) + + async def render_async(self, *args, **kwargs): + """Render the template with the given arguments asynchronously and + return it as a string.""" + response = '' + async for chunk in self.generate_async(*args, **kwargs): + response += chunk + return response diff --git a/requests/__init__.mpy b/requests/__init__.mpy new file mode 100644 index 0000000000000000000000000000000000000000..d9e62348da624d37ee55ef917d4efd2b5a5beac6 GIT binary patch literal 2434 zcmZWp+i%;}89$;dSyp6)qG(#NWt$a`C|SP9mKAk%mc*tgS+?ZZy85yOEJM+eR$Ej_ zdaakWp_mQW_OSn9!2STo(b}ZD1RIKN-P&$LHmoU@6nov{9)>>dP`0}S6Cs}O{J!t| zeOGYZ2_|-ZQB{8PP}X!UE(pz5Qx}BU_7S+~(G=+oSqFGY*4j!-lYwV$c7AR)0la4p zpyOYd#BViQq9!$)pfgF{N{GTOQIk_i;64R$PSlzbediruGJ|$^L)Y7jz+)ZSB5>;a z8o(9-bz9cOdR=X{_7p&sH_~~bTq$Li*8!=?y52mHl|vnPC8gDpB^{9Eav`m>bop%^ zu$!fg%0_ylAXExvAzfI_+^7H|ohem>8yhz=fcmzqHusK%1Enql>Md1l3ukqWYqE-z zLPOE?mUtiobJdkYGPv%pA0D(d;QE%Pv{2N>Ve1XK4oDQz3g1Lv=VE8ZJvZncbc4=a z(`QD^v^n2-)47tV07}zE{ZJDmE22l0#U4vT7VEN#NG$=la+zfwxY|}Y*XBmK0-SAB zp>wlZ0q(Y_OAX-OT&|>Zz*Cp^WgQXmnL?(L0gO`>-vYSYLiEknZQvCoQEH&{tnTP? zO9!1hu3l%{`-)iC03`^hN)2T#2p~Rl8E>mfTUPZW;5m@>hEfO4LvY^?JkEWx9 z4}eQN)EmHi-uaQXsA{rsUPC8t#+A<1&XypWeryjo9e3VYL=$!lE+aQZZGK1Wb;&_h zTiRi}t*E+Or%_y8k(B*K8aRDnjgF&#u2R{I&(ESx0*9wJ(w06`Icm$mfs5_-epBi( z#e0)*_$sQ{()MI^CUaiC!+Y*kkXyG=&Q8qY#}J{jFZ{N}J7xu;WmQ$wMVb@q<=%J- z&6d3A4r{MJKJ3KGe0GtB_7#9Wf=O%<&K9j?Hso&Ks|@V)QO zLEGLB=^y{JPXF&QilZ2XRT2QVCo#%i?EOly{_=v2awITnAz`N?i9YIb!a=>7a8j=& z`l%~;AC)Oeu6scjWpB{+rTXhXa5rdQmRiB>R4xIis|g#$`c@NI(nc+#(~izGVquDs zqq8*Py}G^g`VtJ^KMq{Or;h{Ii^qYbqGSS}m)$np;V3-zU*iktMY1WdgS6L9E$)FC zVtL!}MMC*#_}ITxe9rH#Ua;)A6tFm6bh(KPq>k&7E4^)u5~XAR^0EJV5|cX7SKQ3D zF`hxSGrzu1@<@}iNxy4<@{{k|{)0Kz@&zWbeq&?Pmm)- zsdOx|a6Pp-9S((K@%2iCV^${T3hPl8k#|u&nQ8DLm+g|y6!@_C(Mi@%6i@K}V$z-+ zup}3fp8Vmx;`1kb8|6Rn$M@#2k3V4m>HftpQ`o0O>F4ByU!`otyR%po_aw1T@j=p! zfG%UJ<^1hZfgX2hNyZIRRJ>^$I#wP3M zzWEp@p7QG`sGlW6_pTsll_)dV<9{b@GEPF@Gky#DPyn)nXfVX?Co0gNw4oCthoBGo z4FdC1sXlfH4jEN;sK`HLiAg-5^Rt^(1LaqIUP!xFtM%YQ%Q&aIGukW9*2P4Mgeg^wM&j30*j1 z)zrfVkid~&i6{I&jSu+08C^2MIEyI^hfn##6CuyCqdglpQH7&Pn^l^1hodj?1S|}X zT9ifPGI`{LJDrehmiADj7rmN-&5| zCoyY#xW5=Zn~TWiw??RwgJWCY_&n4@kfZjGohI2^N5<(AN9@f@c;EIsG<&SRYxc)p+`>;#-3XHo!CaHPiN9>O4` zlGqOS@ZB|hi`e>tU$t6leSzvXAMr + + + + + +

Last update: {{c["ura"]["gmtime"]}} GMT

+{% for stop_id in c["stops"] %} + {% set s = c["stops"][stop_id] %} +

{{s["name"]}}

+
(Towards: {{s["towards"]}})
+ {% for m in s["messages"] %} +

{{m}}

+ {% endfor %} + {% for line_name in s["lines"] %} +

{{line_name}}

+
    + {% for pred in s["lines"][line_name] %} +
  1. {{pred["est_due"]}} ➡ {{pred["destination"]}} + {% endfor %} +
+ {% endfor %} +{% endfor %} + + \ No newline at end of file diff --git a/tflcountdown/__init__.py b/tflcountdown/__init__.py new file mode 100644 index 0000000..0078625 --- /dev/null +++ b/tflcountdown/__init__.py @@ -0,0 +1,149 @@ +import json +import time +import urequests + +class TflCountdown: + API_URL = "https://countdown.api.tfl.gov.uk/interfaces/ura/instant_V1" + FIELD_NAMES = [ + # ResponseType 0 - Stop + ["ResponseType", "StopPointName", "StopID", "StopCode1", "StopCode2", "StopPointType", "Towards", "Bearing", "StopPointIndicator", "StopPointState", "Latitude", "Longitude"], + # ResponseType 1 - Prediction + ["ResponseType", "StopPointName", "StopID", "StopCode1", "StopCode2", "StopPointType", "Towards", "Bearing", "StopPointIndicator", "StopPointState", "Latitude", "Longitude", "VisitNumber", "LineID", "LineName", "DirectionID", "DestinationText", "DestinationName", "VehicleID", "TripID", "RegistrationNumber", "EstimatedTime", "ExpireTime"], + # ResponseType 2 - Flexible Message + ["ResponseType", "StopPointName", "StopID", "StopCode1", "StopCode2", "StopPointType", "Towards", "Bearing", "StopPointIndicator", "StopPointState", "Latitude", "Longitude", "MessageUUID", "MessageType", "MessagePriority", "MessageText", "StartTime", "ExpireTime"], + # ResponseType 3 - Baseversion + ["ResponseType", "Version"], + # ResponseType 4 - URA Version + ["ResponseType", "Version", "TimeStamp"] + ] + + def __init__(self, api_key: str): + self.api_key = api_key + self.return_list = ["StopPointName", "StopID", "Towards", "LineName", "DestinationText", "EstimatedTime", "MessageText"] + self.stop_ids = ["1597", "1598", "11333", "11334"] + self.time_now = time.gmtime() * 1000 + + def get_countdown(self, stop_ids: list = ["1598", "11333"], line_ids: list = []): + url = self.API_URL + + params = { + "StopAlso": "true", + "ReturnList": ",".join(self.return_list), + "StopID": ",".join(stop_ids), + "LineID": ",".join(line_ids) + } + url += self.get_query(params) + print(url) + result = urequests.get(url) + return result + + def get_query(self, params: dict): + query_str = "?" + for k in params: + if params[k] == "": + continue + query_str += f"&{k}={params[k]}" + return query_str + + def get_field(self, msg: list, field_name: str): + resp_type = msg[0] + try: + full_fields = self.FIELD_NAMES[resp_type] + except: + print(f"Unknown ResponseType: {resp_type}") + return None + if resp_type in [4]: + return_fields = full_fields + else: + return_fields = ["ResponseType"] + [f for f in self.return_list if f in full_fields] + #print(repr(return_fields)) + #print(repr(msg)) + try: + return_idx = return_fields.index(field_name) + except: + return None + return msg[return_idx] + + def strftime(self, time_tuple: tuple[int, ...], date: bool = True, time: bool = True): + result = "" + if date: + result += f"{time_tuple[0]}-{time_tuple[1]:02}-{time_tuple[2]:02}" + if time: result += " " + if time: + result += f"{time_tuple[3]:02}:{time_tuple[4]:02}:{time_tuple[5]:02}" + return result + + def strfstamp(self, gmstamp_ms: int, return_date: bool = True, return_time: bool = True): + tm_obj = time.gmtime(gmstamp_ms/1000) + return self.strftime(tm_obj, return_date, return_time) + + def get_due(self, gmstamp_ms: int): + diff = gmstamp_ms - self.time_now + minutes = round(diff / 60000) + return f"{minutes}min" + + def parse_countdown(self, ctdn_response: str): + """ + https://content.tfl.gov.uk/tfl-live-bus-river-bus-arrivals-api-documentation.pdf + + Optimised for a value of: + "ReturnList": "StopPointName,StopID,LineName,DestinationText,EstimatedTime,MessageText" + """ + result = {} + lines = ctdn_response.split("\r\n") + for l in lines: + ld = json.loads(l) + print(repr(ld)) + if ld[0] == 0: + # Stop record + stop_id = self.get_field(ld, "StopID") + if not "stops" in result: + result["stops"] = {} + result["stops"][stop_id] = { + "id": stop_id, + "name": self.get_field(ld, "StopPointName"), + "towards": self.get_field(ld, "Towards"), + "lines": {}, + "messages": [] + } + elif ld[0] == 1: + # Prediction record + stop_id = self.get_field(ld, "StopID") + line_no = self.get_field(ld, "LineName") + destination = self.get_field(ld, "DestinationText") + est_stamp = int(self.get_field(ld, "EstimatedTime")) + if not line_no in result["stops"][stop_id]["lines"]: + result["stops"][stop_id]["lines"][line_no] = [] + result["stops"][stop_id]["lines"][line_no].append({ + "destination": destination, + "est_gmstamp_ms": est_stamp, + "est_gmtime": self.strfstamp(est_stamp), + "est_due": self.get_due(est_stamp) + }) + elif ld[0] == 2: + # Flexible Message record + stop_id = self.get_field(ld, "StopID") + msg = self.get_field(ld, "MessageText") + result["stops"][stop_id]["messages"].append(msg) + elif ld[0] == 3: + # Baseversion record + pass + elif ld[0] == 4: + # URA Version record + ura_stamp = int(self.get_field(ld, "TimeStamp")) + # Use as reference time - perfect for embedded systems without RTC + self.time_now = ura_stamp + result["ura"] = { + "version": self.get_field(ld, "Version"), + "gmstamp_ms": ura_stamp, + "gmtime": self.strfstamp(ura_stamp) + } + else: + print(f"Unsupported ResponseType: {ld[0]}") + + # Sort arrivals by estimated time + for stop_id in result["stops"]: + for line_id in result["stops"][stop_id]["lines"]: + result["stops"][stop_id]["lines"][line_id] = sorted(result["stops"][stop_id]["lines"][line_id], key=lambda x: x["est_gmstamp_ms"]) + + return result diff --git a/urequests.mpy b/urequests.mpy new file mode 100644 index 0000000000000000000000000000000000000000..8677c9b5bf5731806331ed76eb794c64d859abc5 GIT binary patch literal 101 zcmeZeW02=ykSHxmEi6qfE-BV4sAS;hGSIWoGhmPvfQd4Q#mASW78Pga=f%e}a3q$L y6frdFH}WuWF(xRaBp5M>F$*w>u?8kHHcBXRF|#E&xHbezvm6v+-?+hqDH#Ar)*0IX literal 0 HcmV?d00001 diff --git a/utemplate/compiled.py b/utemplate/compiled.py new file mode 100644 index 0000000..82237a4 --- /dev/null +++ b/utemplate/compiled.py @@ -0,0 +1,14 @@ +class Loader: + + def __init__(self, pkg, dir): + if dir == ".": + dir = "" + else: + dir = dir.replace("/", ".") + "." + if pkg and pkg != "__main__": + dir = pkg + "." + dir + self.p = dir + + def load(self, name): + name = name.replace(".", "_") + return __import__(self.p + name, None, None, (name,)).render diff --git a/utemplate/recompile.py b/utemplate/recompile.py new file mode 100644 index 0000000..c728c7c --- /dev/null +++ b/utemplate/recompile.py @@ -0,0 +1,21 @@ +# (c) 2014-2020 Paul Sokolovsky. MIT license. +try: + from uos import stat, remove +except: + from os import stat, remove +from . import source + + +class Loader(source.Loader): + + def load(self, name): + o_path = self.pkg_path + self.compiled_path(name) + i_path = self.pkg_path + self.dir + "/" + name + try: + o_stat = stat(o_path) + i_stat = stat(i_path) + if i_stat[8] > o_stat[8]: + # input file is newer, remove output to force recompile + remove(o_path) + finally: + return super().load(name) diff --git a/utemplate/source.py b/utemplate/source.py new file mode 100644 index 0000000..429f589 --- /dev/null +++ b/utemplate/source.py @@ -0,0 +1,188 @@ +# (c) 2014-2019 Paul Sokolovsky. MIT license. +from . import compiled + + +class Compiler: + + START_CHAR = "{" + STMNT = "%" + STMNT_END = "%}" + EXPR = "{" + EXPR_END = "}}" + + def __init__(self, file_in, file_out, indent=0, seq=0, loader=None): + self.file_in = file_in + self.file_out = file_out + self.loader = loader + self.seq = seq + self._indent = indent + self.stack = [] + self.in_literal = False + self.flushed_header = False + self.args = "*a, **d" + + def indent(self, adjust=0): + if not self.flushed_header: + self.flushed_header = True + self.indent() + self.file_out.write("def render%s(%s):\n" % (str(self.seq) if self.seq else "", self.args)) + self.stack.append("def") + self.file_out.write(" " * (len(self.stack) + self._indent + adjust)) + + def literal(self, s): + if not s: + return + if not self.in_literal: + self.indent() + self.file_out.write('yield """') + self.in_literal = True + self.file_out.write(s.replace('"', '\\"')) + + def close_literal(self): + if self.in_literal: + self.file_out.write('"""\n') + self.in_literal = False + + def render_expr(self, e): + self.indent() + self.file_out.write('yield str(' + e + ')\n') + + def parse_statement(self, stmt): + tokens = stmt.split(None, 1) + if tokens[0] == "args": + if len(tokens) > 1: + self.args = tokens[1] + else: + self.args = "" + elif tokens[0] == "set": + self.indent() + self.file_out.write(stmt[3:].strip() + "\n") + elif tokens[0] == "include": + if not self.flushed_header: + # If there was no other output, we still need a header now + self.indent() + tokens = tokens[1].split(None, 1) + args = "" + if len(tokens) > 1: + args = tokens[1] + if tokens[0][0] == "{": + self.indent() + # "1" as fromlist param is uPy hack + self.file_out.write('_ = __import__(%s.replace(".", "_"), None, None, 1)\n' % tokens[0][2:-2]) + self.indent() + self.file_out.write("yield from _.render(%s)\n" % args) + return + + with self.loader.input_open(tokens[0][1:-1]) as inc: + self.seq += 1 + c = Compiler(inc, self.file_out, len(self.stack) + self._indent, self.seq) + inc_id = self.seq + self.seq = c.compile() + self.indent() + self.file_out.write("yield from render%d(%s)\n" % (inc_id, args)) + elif len(tokens) > 1: + if tokens[0] == "elif": + assert self.stack[-1] == "if" + self.indent(-1) + self.file_out.write(stmt + ":\n") + else: + self.indent() + self.file_out.write(stmt + ":\n") + self.stack.append(tokens[0]) + else: + if stmt.startswith("end"): + assert self.stack[-1] == stmt[3:] + self.stack.pop(-1) + elif stmt == "else": + assert self.stack[-1] == "if" + self.indent(-1) + self.file_out.write("else:\n") + else: + assert False + + def parse_line(self, l): + while l: + start = l.find(self.START_CHAR) + if start == -1: + self.literal(l) + return + self.literal(l[:start]) + self.close_literal() + sel = l[start + 1] + #print("*%s=%s=" % (sel, EXPR)) + if sel == self.STMNT: + end = l.find(self.STMNT_END) + assert end > 0 + stmt = l[start + len(self.START_CHAR + self.STMNT):end].strip() + self.parse_statement(stmt) + end += len(self.STMNT_END) + l = l[end:] + if not self.in_literal and l == "\n": + break + elif sel == self.EXPR: + # print("EXPR") + end = l.find(self.EXPR_END) + assert end > 0 + expr = l[start + len(self.START_CHAR + self.EXPR):end].strip() + self.render_expr(expr) + end += len(self.EXPR_END) + l = l[end:] + else: + self.literal(l[start]) + l = l[start + 1:] + + def header(self): + self.file_out.write("# Autogenerated file\n") + + def compile(self): + self.header() + for l in self.file_in: + self.parse_line(l) + self.close_literal() + return self.seq + + +class Loader(compiled.Loader): + + def __init__(self, pkg, dir): + super().__init__(pkg, dir) + self.dir = dir + if pkg == "__main__": + # if pkg isn't really a package, don't bother to use it + # it means we're running from "filesystem directory", not + # from a package. + pkg = None + + self.pkg_path = "" + if pkg: + p = __import__(pkg) + if isinstance(p.__path__, str): + # uPy + self.pkg_path = p.__path__ + else: + # CPy + self.pkg_path = p.__path__[0] + self.pkg_path += "/" + + def input_open(self, template): + path = self.pkg_path + self.dir + "/" + template + return open(path) + + def compiled_path(self, template): + return self.dir + "/" + template.replace(".", "_") + ".py" + + def load(self, name): + try: + return super().load(name) + except (OSError, ImportError): + pass + + compiled_path = self.pkg_path + self.compiled_path(name) + + f_in = self.input_open(name) + f_out = open(compiled_path, "w") + c = Compiler(f_in, f_out, loader=self) + c.compile() + f_in.close() + f_out.close() + return super().load(name)