Files
HeurAMS/webshare/_binary_encode.py
2025-07-23 13:55:24 +08:00

326 lines
7.5 KiB
Python

"""
An encoding / decoding format suitable for serializing data structures to binary.
This is based on https://en.wikipedia.org/wiki/Bencode with some extensions.
The following data types may be encoded:
- None
- int
- bool
- bytes
- str
- list
- tuple
- dict
"""
from __future__ import annotations
from typing import Any, Callable
class DecodeError(Exception):
    """Raised when encoded data cannot be decoded."""
def dump(data: object) -> bytes:
    """Encode a data structure into bytes.

    Every encoded value starts with a marker that identifies its type:

    - ``N`` for None, ``T`` / ``F`` for booleans
    - ``i<number>e`` for ints
    - ``<length>:<raw bytes>`` for bytes
    - ``s<length>:<utf-8 bytes>`` for str
    - ``l...e``, ``t...e`` and ``d...e`` for list, tuple and dict, where
      ``...`` is the concatenation of the encoded members (for a dict,
      alternating keys and values).

    Args:
        data: Data structure

    Raises:
        TypeError: If `data`, or any value nested inside it, is not one of
            the supported types.

    Returns:
        A byte string encoding the data.
    """

    def encode_none(_datum: None) -> bytes:
        """
        Encodes a None value.

        Args:
            _datum: Always None (unused).

        Returns:
            None encoded.
        """
        return b"N"

    def encode_bool(datum: bool) -> bytes:
        """
        Encode a boolean value.

        Args:
            datum: The boolean value to encode.

        Returns:
            The encoded bytes.
        """
        return b"T" if datum else b"F"

    def encode_int(datum: int) -> bytes:
        """
        Encode an integer value.

        Args:
            datum: The integer value to encode.

        Returns:
            The encoded bytes.
        """
        return b"i%ie" % datum

    def encode_bytes(datum: bytes) -> bytes:
        """
        Encode a bytes value.

        Args:
            datum: The bytes value to encode.

        Returns:
            The encoded bytes.
        """
        return b"%i:%s" % (len(datum), datum)

    def encode_string(datum: str) -> bytes:
        """
        Encode a string value.

        Args:
            datum: The string value to encode.

        Returns:
            The encoded bytes.
        """
        encoded_data = datum.encode("utf-8")
        # The length prefix counts UTF-8 bytes, not characters.
        return b"s%i:%s" % (len(encoded_data), encoded_data)

    def encode_list(datum: list) -> bytes:
        """
        Encode a list value.

        Args:
            datum: The list value to encode.

        Returns:
            The encoded bytes.
        """
        return b"l%se" % b"".join(encode(element) for element in datum)

    def encode_tuple(datum: tuple) -> bytes:
        """
        Encode a tuple value.

        Args:
            datum: The tuple value to encode.

        Returns:
            The encoded bytes.
        """
        return b"t%se" % b"".join(encode(element) for element in datum)

    def encode_dict(datum: dict) -> bytes:
        """
        Encode a dictionary value.

        Args:
            datum: The dictionary value to encode.

        Returns:
            The encoded bytes.
        """
        return b"d%se" % b"".join(
            encode(key) + encode(value) for key, value in datum.items()
        )

    # Dispatch is on the exact type, so a bool is routed to encode_bool
    # (never encode_int) and subclasses of the listed types are rejected.
    ENCODERS: dict[type, Callable[[Any], Any]] = {
        type(None): encode_none,
        bool: encode_bool,
        int: encode_int,
        bytes: encode_bytes,
        str: encode_string,
        list: encode_list,
        tuple: encode_tuple,
        dict: encode_dict,
    }

    def encode(datum: object) -> bytes:
        """Recursively encode data.

        Args:
            datum: Data suitable for encoding.

        Raises:
            TypeError: If `datum` is not one of the supported types.

        Returns:
            Encoded data bytes.
        """
        try:
            encoder = ENCODERS[type(datum)]
        except KeyError:
            # Suppress the KeyError context; it is an implementation detail.
            raise TypeError(f"Can't encode {datum!r}") from None
        return encoder(datum)

    return encode(data)
def load(encoded: bytes) -> object:
    """Load an encoded data structure from bytes.

    Only the first complete value in `encoded` is decoded; any bytes that
    follow it are ignored.

    Args:
        encoded: Encoded data in bytes.

    Raises:
        TypeError: If `encoded` is not a bytes object.
        DecodeError: If the data ends before a value is complete, or a
            length prefix asks for more bytes than are available.
        ValueError: If an integer or length field is not a valid number.
            This includes data that starts with an unknown type marker,
            because such a byte is treated as the first digit of a bytes
            length.

    Returns:
        Decoded data.
    """
    if not isinstance(encoded, bytes):
        raise TypeError("must be bytes")
    max_position = len(encoded)
    position = 0

    def get_byte() -> bytes:
        """Get an encoded byte and advance position.

        Raises:
            DecodeError: If the end of the data was reached

        Returns:
            A bytes object with a single byte.
        """
        nonlocal position
        if position >= max_position:
            raise DecodeError("More data expected")
        character = encoded[position : position + 1]
        position += 1
        return character

    def peek_byte() -> bytes:
        """Get the byte at the current position, but don't advance position.

        Returns:
            A bytes object with a single byte, or an empty bytes object if
            the position is at the end of the data.
        """
        return encoded[position : position + 1]

    def get_bytes(size: int) -> bytes:
        """Get a number of bytes of encoded data.

        Args:
            size: Number of bytes to retrieve.

        Raises:
            DecodeError: If there aren't enough bytes.

        Returns:
            A bytes object.
        """
        nonlocal position
        bytes_data = encoded[position : position + size]
        # A short slice means truncated data; a negative size can never
        # match a slice length, so it is rejected here as well.
        if len(bytes_data) != size:
            raise DecodeError(
                f"Missing bytes; expected {size}, got {len(bytes_data)}"
            )
        position += size
        return bytes_data

    def read_until(terminator: bytes) -> bytes:
        """Read up to a single-byte terminator and step over it.

        Args:
            terminator: The byte that ends the field.

        Raises:
            DecodeError: If the terminator does not occur in the remaining data.

        Returns:
            The bytes before the terminator (the terminator is not included).
        """
        nonlocal position
        end = encoded.find(terminator, position)
        if end == -1:
            raise DecodeError("More data expected")
        field = encoded[position:end]
        position = end + 1
        return field

    def decode_int() -> int:
        """Decode an int from the encoded data.

        Returns:
            An integer.
        """
        return int(read_until(b"e"))

    def decode_bytes(size_bytes: bytes) -> bytes:
        """Decode a bytes string from the encoded data.

        Args:
            size_bytes: The first byte of the length prefix, which the
                caller has already consumed.

        Returns:
            A bytes object.
        """
        size_bytes += read_until(b":")
        return get_bytes(int(size_bytes))

    def decode_string() -> str:
        """Decode a (utf-8 encoded) string from the encoded data.

        Returns:
            A string.
        """
        bytes_string = get_bytes(int(read_until(b":")))
        # Invalid UTF-8 is replaced rather than raising.
        return bytes_string.decode("utf-8", errors="replace")

    def decode_list() -> list[object]:
        """Decode a list.

        Returns:
            A list of data.
        """
        elements: list[object] = []
        add_element = elements.append
        # At end of data peek_byte() returns b"", so decode() is called and
        # raises DecodeError instead of looping forever.
        while peek_byte() != b"e":
            add_element(decode())
        get_byte()  # Consume the closing b"e".
        return elements

    def decode_tuple() -> tuple[object, ...]:
        """Decode a tuple.

        Returns:
            A tuple of decoded data.
        """
        # Tuples share the list body layout; only the leading marker differs.
        return tuple(decode_list())

    def decode_dict() -> dict[object, object]:
        """Decode a dict.

        Returns:
            A dict of decoded data.
        """
        elements: dict[object, object] = {}
        while peek_byte() != b"e":
            # The key is encoded first, then its value.
            key = decode()
            elements[key] = decode()
        get_byte()  # Consume the closing b"e".
        return elements

    DECODERS = {
        b"i": decode_int,
        b"s": decode_string,
        b"l": decode_list,
        b"t": decode_tuple,
        b"d": decode_dict,
        b"T": lambda: True,
        b"F": lambda: False,
        b"N": lambda: None,
    }

    def decode() -> object:
        """Recursively decode data.

        Returns:
            Decoded data.
        """
        decoder = DECODERS.get(initial := get_byte(), None)
        if decoder is None:
            # Bytes values have no marker: the value starts directly with
            # its length, so the byte just read is the first length digit.
            return decode_bytes(initial)
        return decoder()

    return decode()