# Source code for aio_pika.message

import json
import time
from datetime import datetime, timedelta, timezone
from functools import singledispatch
from pprint import pformat
from types import TracebackType
from typing import (
    Any, Callable, Dict, Iterable, Iterator, List, MutableMapping, Optional,
    Type, TypeVar, Union,
)

import aiormq
from aiormq.abc import DeliveredMessage, FieldTable
from pamqp.common import FieldValue

from .abc import (
    MILLISECONDS, ZERO_TIME, AbstractChannel, AbstractIncomingMessage,
    AbstractMessage, AbstractProcessContext, DateType, DeliveryMode,
    HeadersPythonValues, HeadersType, NoneType,
)
from .exceptions import MessageProcessError
from .log import get_logger


# Module-level logger obtained through the package's ``get_logger`` helper;
# ProcessContext uses it to report rejects that could not be sent.
log = get_logger(__name__)


def to_milliseconds(seconds: Union[float, int]) -> int:
    """Convert *seconds* into a whole number of milliseconds.

    The value is scaled by ``MILLISECONDS`` (imported from ``.abc``;
    presumably 1000 — confirm there) and truncated toward zero by ``int``.
    """
    scaled = seconds * MILLISECONDS
    return int(scaled)


@singledispatch
def encode_expiration(value: Any) -> Optional[str]:
    """Encode a message expiration into its AMQP wire form.

    The result is a decimal string holding a number of milliseconds, or
    ``None`` when there is no expiration. Supported inputs:

    * ``int`` / ``float`` -- a duration in seconds;
    * ``timedelta`` -- a duration;
    * ``datetime`` -- an absolute deadline, measured from "now";
    * ``None`` -- no expiration.

    :raises ValueError: for any other type.
    """
    # The message previously said "timestamp" (copied from encode_timestamp);
    # this dispatcher encodes expirations.
    raise ValueError("Invalid expiration type: %r" % type(value), value)


@encode_expiration.register(datetime)
def encode_expiration_datetime(value: datetime) -> str:
    # "Now" is taken in the deadline's own tzinfo so naive and aware values
    # are both subtracted consistently. A deadline in the past produces a
    # negative number of milliseconds.
    now = datetime.now(tz=value.tzinfo)
    return str(to_milliseconds((value - now).total_seconds()))


@encode_expiration.register(int)
@encode_expiration.register(float)
def encode_expiration_number(value: Union[int, float]) -> str:
    # Plain numbers are interpreted as seconds.
    return str(to_milliseconds(value))


@encode_expiration.register(timedelta)
def encode_expiration_timedelta(value: timedelta) -> str:
    # Reuse the shared helper instead of repeating the scaling inline.
    return str(to_milliseconds(value.total_seconds()))


@encode_expiration.register(NoneType)       # type: ignore
def encode_expiration_none(_: Any) -> None:
    return None


@singledispatch
def decode_expiration(t: Any) -> Optional[float]:
    """Decode an expiration value into a float.

    * ``str`` (the wire form) is parsed with ``float``. IncomingMessage
      treats that number as milliseconds and divides it by 1000.
    * ``time.struct_time`` is converted to seconds elapsed since
      ``ZERO_TIME``.

    :raises ValueError: for any other type.
    """
    raise ValueError("Invalid expiration type: %r" % type(t), t)


@decode_expiration.register(time.struct_time)
def decode_expiration_struct_time(t: time.struct_time) -> float:
    # struct_time indices 0..5 are year, month, day, hour, minute, second.
    # Slicing to 7 also passed tm_wday (index 6) as the microsecond argument.
    # NOTE(review): ZERO_TIME comes from .abc and is presumably the naive
    # epoch datetime -- confirm there.
    return (datetime(*t[:6]) - ZERO_TIME).total_seconds()


@decode_expiration.register(str)
def decode_expiration_str(t: str) -> float:
    return float(t)


@singledispatch
def encode_timestamp(value: Any) -> Optional[datetime]:
    """Normalise a user-supplied timestamp into a ``datetime``.

    * ``time.struct_time`` -- converted from its year..second fields;
    * ``datetime`` -- returned unchanged;
    * ``int`` / ``float`` -- POSIX timestamp, returned as a naive UTC
      ``datetime``;
    * ``timedelta`` -- offset added to the current naive UTC time;
    * ``None`` -- returned as ``None``.

    :raises ValueError: for any other type.
    """
    raise ValueError("Invalid timestamp type: %r" % type(value), value)


@encode_timestamp.register(time.struct_time)
def encode_timestamp_struct_time(value: time.struct_time) -> datetime:
    # Only year..second are used; tm_wday/tm_yday/tm_isdst are dropped.
    return datetime(*value[:6])


@encode_timestamp.register(datetime)
def encode_timestamp_datetime(value: datetime) -> datetime:
    return value


@encode_timestamp.register(float)
@encode_timestamp.register(int)
def encode_timestamp_number(value: Union[int, float]) -> datetime:
    # Same naive-UTC result as the deprecated datetime.utcfromtimestamp().
    aware = datetime.fromtimestamp(value, tz=timezone.utc)
    return aware.replace(tzinfo=None)


@encode_timestamp.register(timedelta)
def encode_timestamp_timedelta(value: timedelta) -> datetime:
    # Same naive-UTC "now" as the deprecated datetime.utcnow().
    now = datetime.now(tz=timezone.utc).replace(tzinfo=None)
    return now + value


@encode_timestamp.register(NoneType)        # type: ignore
def encode_timestamp_none(_: Any) -> None:
    return None


@singledispatch
def decode_timestamp(value: Any) -> Optional[datetime]:
    """Convert a timestamp value into a ``datetime`` (or ``None``).

    * ``datetime`` -- returned unchanged;
    * ``int`` / ``float`` -- POSIX timestamp, returned as a naive UTC
      ``datetime``;
    * ``time.struct_time`` -- converted from its year..second fields;
    * ``None`` -- returned as ``None``.

    :raises ValueError: for any other type.
    """
    raise ValueError("Invalid timestamp type: %r" % type(value), value)


@decode_timestamp.register(datetime)
def decode_timestamp_datetime(value: datetime) -> datetime:
    return value


@decode_timestamp.register(float)
@decode_timestamp.register(int)
def decode_timestamp_number(value: Union[float, int]) -> datetime:
    # Same naive-UTC result as the deprecated datetime.utcfromtimestamp().
    aware = datetime.fromtimestamp(value, tz=timezone.utc)
    return aware.replace(tzinfo=None)


@decode_timestamp.register(time.struct_time)
def decode_timestamp_struct_time(value: time.struct_time) -> datetime:
    # Only year..second are used; tm_wday/tm_yday/tm_isdst are dropped.
    return datetime(*value[:6])


@decode_timestamp.register(NoneType)    # type: ignore
def decode_timestamp_none(_: Any) -> None:
    return None


# Generic type variables used by the helpers in this module.
D = TypeVar("D")
T = TypeVar("T")
V = TypeVar("V")


def optional(
    value: V,
    func: Union[Callable[[V], T], Type[T]],
    default: D = None,
) -> Union[T, D]:
    """Return ``func(value)`` when *value* is truthy, else *default*.

    Every falsy input (``None``, ``0``, ``""``, empty containers) yields
    *default*, not only ``None``.
    """
    if not value:
        return default              # type: ignore
    return func(value)              # type: ignore


class HeaderProxy(MutableMapping):
    """Mutable mapping view over a raw AMQP header table.

    Reads are cached per key, and ``bytes`` values are decoded to ``str``
    on first access when they are valid UTF-8. Writes pass through
    :func:`header_converter` before they are stored in the wrapped table.
    """

    def __init__(self, headers: FieldTable):
        # The table is shared, not copied: changes made through the proxy
        # land directly in the caller's dict.
        self._headers: FieldTable = headers
        # Decoded values, keyed by header name.
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, k: str) -> FieldValue:
        # Consult the backing table first so removed keys are never served
        # from the cache.
        if k not in self._headers:
            raise KeyError(k)

        if k not in self._cache:
            value = self._headers[k]

            if isinstance(value, bytes):
                try:
                    value = value.decode()
                except UnicodeDecodeError:
                    # Not UTF-8 text: expose the raw bytes rather than
                    # making the lookup itself fail.
                    pass

            self._cache[k] = value

        return self._cache[k]

    def __delitem__(self, key: str) -> None:
        del self._headers[key]
        # Forget the decoded copy too, so it cannot go stale if the key is
        # later written to the backing table directly.
        self._cache.pop(key, None)

    def __setitem__(self, key: str, value: FieldValue) -> None:
        self._headers[key] = header_converter(value)
        self._cache.pop(key, None)

    def __len__(self) -> int:
        return len(self._headers)

    def __iter__(self) -> Iterator[str]:
        yield from self._headers


@singledispatch
def header_converter(value: Any) -> FieldValue:
    """Convert a Python value into something storable in an AMQP header.

    This is the fallback for types without a dedicated handler. The value
    is serialised to compact JSON, with non-ASCII characters kept as-is and
    unknown objects rendered through ``repr``. The result is a
    ``bytearray``.
    """
    serialized = json.dumps(
        value,
        default=repr,
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return bytearray(serialized.encode())


# Types stored in the header table unchanged.
@header_converter.register(NoneType)        # type: ignore
@header_converter.register(bytearray)
@header_converter.register(str)
@header_converter.register(datetime)
@header_converter.register(time.struct_time)
@header_converter.register(list)
@header_converter.register(int)
def header_converter_native(v: T) -> T:
    return v


@header_converter.register(bytes)
def header_converter_bytes(v: bytes) -> bytearray:
    # Immutable bytes are stored as a mutable bytearray copy.
    converted = bytearray(v)
    return converted


@header_converter.register(set)
@header_converter.register(tuple)
@header_converter.register(frozenset)
def header_converter_iterable(v: Iterable[T]) -> List[T]:
    # Materialise as a list, then dispatch again (lists are kept as-is).
    as_list = list(v)
    return header_converter(as_list)        # type: ignore


def format_headers(d: Optional[HeadersType]) -> FieldTable:
    """Build a new header table with every value passed through
    :func:`header_converter`.

    A ``None`` or empty mapping yields a fresh empty dict.
    """
    if not d:
        return {}
    return {name: header_converter(raw) for name, raw in d.items()}


class Message(AbstractMessage):
    """ AMQP message abstraction """

    __slots__ = (
        "app_id",
        "body",
        "body_size",
        "content_encoding",
        "content_type",
        "correlation_id",
        "delivery_mode",
        "expiration",
        "_headers",
        "headers_raw",
        "message_id",
        "priority",
        "reply_to",
        "timestamp",
        "type",
        "user_id",
        "__lock",
    )

    def __init__(
        self,
        body: bytes,
        *,
        headers: Optional[HeadersType] = None,
        content_type: Optional[str] = None,
        content_encoding: Optional[str] = None,
        delivery_mode: Union[DeliveryMode, int, None] = None,
        priority: Optional[int] = None,
        correlation_id: Optional[str] = None,
        reply_to: Optional[str] = None,
        expiration: Optional[DateType] = None,
        message_id: Optional[str] = None,
        timestamp: Optional[DateType] = None,
        type: Optional[str] = None,
        user_id: Optional[str] = None,
        app_id: Optional[str] = None,
    ):
        """ Creates a new instance of Message

        :param body: message body
        :param headers: message headers
        :param content_type: content type
        :param content_encoding: content encoding
        :param delivery_mode: delivery mode
        :param priority: priority
        :param correlation_id: correlation id
        :param reply_to: reply to
        :param expiration: expiration in seconds (or datetime or timedelta)
        :param message_id: message id
        :param timestamp: timestamp
        :param type: type
        :param user_id: user id
        :param app_id: app id
        """
        # Assigned first: __setattr__ reads ``locked`` (i.e. this flag)
        # whenever a public attribute is set.
        self.__lock = False
        self.body = body if isinstance(body, bytes) else bytes(body)
        self.body_size = len(self.body) if self.body else 0
        self.headers_raw: FieldTable = format_headers(headers)
        self._headers: HeadersType = HeaderProxy(self.headers_raw)
        self.content_type = content_type
        self.content_encoding = content_encoding
        self.delivery_mode: DeliveryMode = DeliveryMode(
            optional(
                delivery_mode, int, DeliveryMode.NOT_PERSISTENT,
            ),
        )
        self.priority = optional(priority, int, 0)
        self.correlation_id = optional(correlation_id, str)
        self.reply_to = optional(reply_to, str)
        self.expiration = expiration
        self.message_id = optional(message_id, str)
        self.timestamp = encode_timestamp(timestamp)
        self.type = optional(type, str)
        self.user_id = optional(user_id, str)
        self.app_id = optional(app_id, str)

    @property
    def headers(self) -> HeadersType:
        """ Mutable, decoding view over :attr:`headers_raw` """
        return self._headers

    @headers.setter
    def headers(self, value: Dict[str, HeadersPythonValues]) -> None:
        # headers_raw goes first: on a locked message __setattr__ raises
        # here, before anything is changed.
        self.headers_raw = format_headers(value)
        # Re-wrap the new table. Otherwise the ``headers`` getter would keep
        # serving the old table while ``properties`` sends the new one.
        self._headers = HeaderProxy(self.headers_raw)

    @staticmethod
    def _as_bytes(value: Any) -> bytes:
        if isinstance(value, bytes):
            return value
        elif isinstance(value, str):
            return value.encode()
        elif value is None:
            return b""
        else:
            return str(value).encode()

    def info(self) -> dict:
        """ Create a dict with message attributes

        ::

            {
                "body_size": 100,
                "headers": {},
                "content_type": "text/plain",
                "content_encoding": "",
                "delivery_mode": DeliveryMode.NOT_PERSISTENT,
                "priority": 0,
                "correlation_id": "",
                "reply_to": "",
                "expiration": "",
                "message_id": "",
                "timestamp": "",
                "type": "",
                "user_id": "",
                "app_id": "",
            }

        """
        return {
            "body_size": self.body_size,
            "headers": self.headers_raw,
            "content_type": self.content_type,
            "content_encoding": self.content_encoding,
            "delivery_mode": self.delivery_mode,
            "priority": self.priority,
            "correlation_id": self.correlation_id,
            "reply_to": self.reply_to,
            "expiration": self.expiration,
            "message_id": self.message_id,
            "timestamp": decode_timestamp(self.timestamp),
            "type": str(self.type),
            "user_id": self.user_id,
            "app_id": self.app_id,
        }

    @property
    def locked(self) -> bool:
        """ is message locked

        :return: :class:`bool`
        """
        return bool(self.__lock)

    @property
    def properties(self) -> aiormq.spec.Basic.Properties:
        """ Build :class:`aiormq.spec.Basic.Properties` object """
        return aiormq.spec.Basic.Properties(
            content_type=self.content_type,
            content_encoding=self.content_encoding,
            headers=self.headers_raw,
            delivery_mode=self.delivery_mode,
            priority=self.priority,
            correlation_id=self.correlation_id,
            reply_to=self.reply_to,
            expiration=encode_expiration(self.expiration),
            message_id=self.message_id,
            timestamp=self.timestamp,
            message_type=self.type,
            user_id=self.user_id,
            app_id=self.app_id,
        )

    def __repr__(self) -> str:
        return "{name}:{repr}".format(
            name=self.__class__.__name__,
            repr=pformat(self.info()),
        )

    def __setattr__(self, key: str, value: FieldValue) -> None:
        # Underscore-prefixed (including name-mangled) attributes stay
        # writable after lock(); everything else is frozen.
        if not key.startswith("_") and self.locked:
            raise ValueError("Message is locked")

        return super().__setattr__(key, value)

    def __iter__(self) -> Iterator[int]:
        return iter(self.body)

    def lock(self) -> None:
        """ Set lock flag to `True`"""
        self.__lock = True

    def __copy__(self) -> "Message":
        return Message(
            body=self.body,
            headers=self.headers_raw,
            content_encoding=self.content_encoding,
            content_type=self.content_type,
            delivery_mode=self.delivery_mode,
            priority=self.priority,
            correlation_id=self.correlation_id,
            reply_to=self.reply_to,
            expiration=self.expiration,
            message_id=self.message_id,
            timestamp=self.timestamp,
            type=self.type,
            user_id=self.user_id,
            app_id=self.app_id,
        )
class IncomingMessage(Message, AbstractIncomingMessage):
    """
    Message received from the broker.

    It behaves like :class:`Message` and adds the methods used to
    acknowledge the delivery. Depending on the acknowledgement mode in use,
    RabbitMQ considers a message successfully delivered either as soon as it
    is written to the TCP socket, or only after the client sends an explicit
    ("manual") acknowledgement.

    Manual acknowledgements can be positive or negative, using one of these
    protocol methods:

    * basic.ack for positive acknowledgements
    * basic.nack for negative acknowledgements (note: this is a RabbitMQ
      extension to AMQP 0-9-1)
    * basic.reject for negative acknowledgements; unlike basic.nack it
      cannot settle several deliveries at once (there is no ``multiple``
      flag)

    A positive acknowledgement tells RabbitMQ to record the message as
    delivered. A negative acknowledgement sent with basic.reject has the
    same effect. The difference is mostly semantic: a positive
    acknowledgement says the message was processed successfully, while a
    negative one says it was not processed but should still be deleted.
    """

    __slots__ = (
        "_loop",
        "__channel",
        "cluster_id",
        "consumer_tag",
        "delivery_tag",
        "exchange",
        "routing_key",
        "redelivered",
        "__no_ack",
        "__processed",
        "message_count",
    )

    def __init__(self, message: DeliveredMessage, no_ack: bool = False):
        """ Create an instance of :class:`IncomingMessage` """
        # Private (name-mangled) attributes bypass the lock check in
        # Message.__setattr__, so they can be set before Message.__init__.
        self.__channel = message.channel
        self.__no_ack = no_ack
        self.__processed = False

        props = message.header.properties

        # The wire expiration is decoded and divided by 1000.0 (milliseconds
        # to seconds). A missing or zero value becomes None.
        expiration: Optional[float] = None
        if props.expiration:
            decoded = decode_expiration(props.expiration)
            if decoded:
                expiration = decoded / 1000.0

        super().__init__(
            body=message.body,
            content_type=props.content_type,
            content_encoding=props.content_encoding,
            headers=props.headers,
            delivery_mode=props.delivery_mode,
            priority=props.priority,
            correlation_id=props.correlation_id,
            reply_to=props.reply_to,
            expiration=expiration,
            message_id=props.message_id,
            timestamp=decode_timestamp(props.timestamp),
            type=props.message_type,
            user_id=props.user_id,
            app_id=props.app_id,
        )

        self.cluster_id = props.cluster_id
        self.consumer_tag = message.consumer_tag
        self.delivery_tag = message.delivery_tag
        self.exchange = message.exchange
        self.message_count = message.message_count
        self.redelivered = message.redelivered
        self.routing_key = message.routing_key

        # Without a delivery tag (or in no_ack mode) there is nothing to
        # acknowledge, so the message starts out locked and processed.
        if no_ack or not self.delivery_tag:
            self.lock()
            self.__processed = True

    @property
    def channel(self) -> aiormq.abc.AbstractChannel:
        """ Channel the message was delivered on """
        return self.__channel

    def process(
        self,
        requeue: bool = False,
        reject_on_redelivered: bool = False,
        ignore_processed: bool = False,
    ) -> AbstractProcessContext:
        """ Context manager for processing the message

            >>> async def on_message_received(message: IncomingMessage):
            ...    async with message.process():
            ...        # When exception will be raised
            ...        # the message will be rejected
            ...        print(message.body)

        Example with ignore_processed=True

            >>> async def on_message_received(message: IncomingMessage):
            ...    async with message.process(ignore_processed=True):
            ...        # Now (with ignore_processed=True) you may reject
            ...        # (or ack) message manually too
            ...        if True:  # some reasonable condition here
            ...            await message.reject()
            ...        print(message.body)

        :param requeue: Requeue message when exception.
        :param reject_on_redelivered:
            When True message will be rejected only when
            message was redelivered.
        :param ignore_processed: Do nothing if message already processed

        """
        context = ProcessContext(
            self,
            requeue=requeue,
            reject_on_redelivered=reject_on_redelivered,
            ignore_processed=ignore_processed,
        )
        return context

    async def ack(self, multiple: bool = False) -> None:
        """ Send basic.ack, a positive acknowledgement.

        .. note::
            This looks like a blocking method, but it only writes bytes to
            the socket and waits for no response from the broker.

        :param multiple: When True, the delivery tag is treated as "up to
            and including", so several messages are acknowledged with one
            method. When False, only this message is acknowledged.
        :raises TypeError: if the message was received with ``no_ack``.
        :raises MessageProcessError: if the message was already processed.
        """
        if self.__no_ack:
            raise TypeError('Can\'t ack message with "no_ack" flag')

        if self.__processed:
            raise MessageProcessError("Message already processed", self)

        tag = self.delivery_tag
        if tag is not None:
            await self.__channel.basic_ack(delivery_tag=tag, multiple=multiple)

        self.__processed = True
        if not self.locked:
            self.lock()

    async def reject(self, requeue: bool = False) -> None:
        """ Send basic.reject for this message.

        With ``requeue=True`` the broker puts the message back on the queue;
        otherwise it is dropped.

        .. note::
            This looks like a blocking method, but it only writes bytes to
            the socket and waits for no response from the broker.

        :param requeue: bool
        :raises TypeError: if the message was received with ``no_ack``.
        :raises MessageProcessError: if the message was already processed.
        """
        if self.__no_ack:
            raise TypeError('This message has "no_ack" flag.')

        if self.__processed:
            raise MessageProcessError("Message already processed", self)

        tag = self.delivery_tag
        if tag is not None:
            await self.__channel.basic_reject(delivery_tag=tag, requeue=requeue)

        self.__processed = True
        if not self.locked:
            self.lock()

    async def nack(
        self,
        multiple: bool = False,
        requeue: bool = True,
    ) -> None:
        """ Send basic.nack, a negative acknowledgement.

        :param multiple: When True, settle every delivery up to and
            including this one.
        :param requeue: When True (the default), the broker requeues the
            message(s).
        :raises RuntimeError: if the connection does not report basic.nack
            support.
        :raises TypeError: if the message was received with ``no_ack``.
        :raises MessageProcessError: if the message was already processed.
        """
        if not self.channel.connection.basic_nack:
            raise RuntimeError("Method not supported on server")

        if self.__no_ack:
            raise TypeError('Can\'t nack message with "no_ack" flag')

        if self.__processed:
            raise MessageProcessError("Message already processed", self)

        tag = self.delivery_tag
        if tag is not None:
            await self.__channel.basic_nack(
                delivery_tag=tag,
                multiple=multiple,
                requeue=requeue,
            )

        self.__processed = True
        if not self.locked:
            self.lock()

    def info(self) -> dict:
        """ Method returns dict representation of the message """
        details = super().info()
        details.update(
            cluster_id=self.cluster_id,
            consumer_tag=self.consumer_tag,
            delivery_tag=self.delivery_tag,
            exchange=self.exchange,
            redelivered=self.redelivered,
            routing_key=self.routing_key,
        )
        return details

    @property
    def processed(self) -> bool:
        """ True once the message has been acked, nacked or rejected """
        return self.__processed
class ReturnedMessage(IncomingMessage):
    """Subclass of IncomingMessage that adds no behaviour.

    Presumably used for messages the broker returns to the publisher;
    verify against the channel code.
    """


# Callback signature: (channel, returned message) -> anything.
ReturnCallback = Callable[[AbstractChannel, ReturnedMessage], Any]


class ProcessContext(AbstractProcessContext):
    """Async context manager returned by :meth:`IncomingMessage.process`.

    On a clean exit the message is acked. On an exception it is rejected
    (requeued according to ``requeue``), unless the channel is already
    closed. With ``ignore_processed=True``, a message that was already
    processed inside the block is left untouched.
    """

    def __init__(
        self,
        message: IncomingMessage,
        *,
        requeue: bool,
        reject_on_redelivered: bool,
        ignore_processed: bool,
    ):
        self.message = message
        self.requeue = requeue
        self.reject_on_redelivered = reject_on_redelivered
        self.ignore_processed = ignore_processed

    async def __aenter__(self) -> IncomingMessage:
        return self.message

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        message = self.message

        # The caller settled the message manually and asked us not to touch
        # it again.
        if self.ignore_processed and message.processed:
            return

        if not exc_type:
            await message.ack()
            return

        if self.reject_on_redelivered and message.redelivered:
            if message.channel.is_closed:
                log.warning(
                    "Message %r was redelivered and reject is not sent "
                    "since channel is closed",
                    message,
                )
                return

            log.info(
                "Message %r was redelivered and will be rejected",
                message,
            )
            await message.reject(requeue=False)
            return

        if message.channel.is_closed:
            log.warning("Reject is not sent since channel is closed")
            return

        await message.reject(requeue=self.requeue)


__all__ = ("Message", "IncomingMessage", "ReturnedMessage")