Source code for aio_pika.patterns.rpc

import asyncio
import json
import logging
import pickle
import time
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from aio_pika.channel import Channel
from aio_pika.exceptions import DeliveryError
from aio_pika.exchange import ExchangeType
from aio_pika.message import (
    DeliveryMode, IncomingMessage, Message, ReturnedMessage,
)
from aio_pika.tools import shield
from aiormq.tools import awaitable

from .base import Base, Proxy


log = logging.getLogger(__name__)

R = TypeVar("R")
P = TypeVar("P")
CallbackType = Callable[[P], R]


class RPCMessageTypes(Enum):
    error = "error"
    result = "result"
    call = "call"


[docs]class RPC(Base): __slots__ = ( "channel", "loop", "proxy", "result_queue", "result_consumer_tag", "routes", "consumer_tags", "dlx_exchange", ) DLX_NAME = "rpc.dlx" DELIVERY_MODE = DeliveryMode.NOT_PERSISTENT __doc__ = """ Remote Procedure Call helper. Create an instance :: rpc = await RPC.create(channel) Registering python function :: # RPC instance passes only keyword arguments def multiply(*, x, y): return x * y await rpc.register("multiply", multiply) Call function through proxy :: assert await rpc.proxy.multiply(x=2, y=3) == 6 Call function explicit :: assert await rpc.call('multiply', dict(x=2, y=3)) == 6 """ def __init__(self, channel: Channel): self.channel = channel self.loop = self.channel.loop self.proxy = Proxy(self.call) self.result_queue = None self.futures = {} # type Dict[int, asyncio.Future] self.result_consumer_tag = None self.routes = {} self.queues = {} self.consumer_tags = {} self.dlx_exchange = None def __remove_future(self, future: asyncio.Future): log.debug("Remove done future %r", future) self.futures.pop(id(future), None) def create_future(self) -> asyncio.Future: future = self.loop.create_future() log.debug("Create future for RPC call") self.futures[id(future)] = future future.add_done_callback(self.__remove_future) return future @shield async def close(self): if self.result_queue is None: log.warning("RPC already closed") return log.debug("Cancelling listening %r", self.result_queue) await self.result_queue.cancel(self.result_consumer_tag) self.result_consumer_tag = None log.debug("Unbinding %r", self.result_queue) await self.result_queue.unbind( self.dlx_exchange, "", arguments={"From": self.result_queue.name, "x-match": "any"}, ) log.debug("Cancelling undone futures %r", self.futures) for future in self.futures.values(): if future.done(): continue future.set_exception(asyncio.CancelledError) log.debug("Deleting %r", self.result_queue) await self.result_queue.delete() self.result_queue = None @shield async def initialize(self, auto_delete=True, durable=False, **kwargs): if self.result_queue is not None: return self.result_queue = await self.channel.declare_queue( None, auto_delete=auto_delete, durable=durable, **kwargs ) self.dlx_exchange = await self.channel.declare_exchange( self.DLX_NAME, type=ExchangeType.HEADERS, auto_delete=True, ) await self.result_queue.bind( self.dlx_exchange, "", arguments={"From": self.result_queue.name, "x-match": "any"}, ) self.result_consumer_tag = await self.result_queue.consume( self.on_result_message, exclusive=True, no_ack=True, ) self.channel.add_close_callback(self.on_close) self.channel.add_on_return_callback( self.on_message_returned, weak=False, ) def on_close(self, exc=None): log.debug("Closing RPC futures because %r", exc) for future in self.futures.values(): if future.done(): continue future.set_exception(exc or Exception)
[docs] @classmethod async def create(cls, channel: Channel, **kwargs) -> "RPC": """ Creates a new instance of :class:`aio_pika.patterns.RPC`. You should use this method instead of :func:`__init__`, because :func:`create` returns coroutine and makes async initialize :param channel: initialized instance of :class:`aio_pika.Channel` :returns: :class:`RPC` """ rpc = cls(channel) await rpc.initialize(**kwargs) return rpc
def on_message_returned(self, sender: "RPC", message: ReturnedMessage): correlation_id = ( int(message.correlation_id) if message.correlation_id else None ) future = self.futures.pop(correlation_id, None) if not future or future.done(): log.warning("Unknown message was returned: %r", message) return future.set_exception(DeliveryError(message, None)) async def on_result_message(self, message: IncomingMessage): correlation_id = ( int(message.correlation_id) if message.correlation_id else None ) future = self.futures.pop(correlation_id, None) if future is None: log.warning("Unknown message: %r", message) return try: payload = self.deserialize(message.body) except Exception as e: log.error("Failed to deserialize response on message: %r", message) future.set_exception(e) return if message.type == RPCMessageTypes.result.value: future.set_result(payload) elif message.type == RPCMessageTypes.error.value: future.set_exception(payload) elif message.type == RPCMessageTypes.call.value: future.set_exception( asyncio.TimeoutError("Message timed-out", message), ) else: future.set_exception( RuntimeError("Unknown message type %r" % message.type), ) async def on_call_message( self, method_name: str, message: IncomingMessage ): if method_name not in self.routes: log.warning("Method %r not registered in %r", method_name, self) return try: payload = self.deserialize(message.body) func = self.routes[method_name] result = await self.execute(func, payload) result = self.serialize(result) message_type = RPCMessageTypes.result.value except Exception as e: result = self.serialize_exception(e) message_type = RPCMessageTypes.error.value if not message.reply_to: log.info( 'RPC message without "reply_to" header %r call result ' "will be lost", message, ) await message.ack() return result_message = Message( result, content_type=self.CONTENT_TYPE, correlation_id=message.correlation_id, delivery_mode=message.delivery_mode, timestamp=time.time(), type=message_type, ) try: await self.channel.default_exchange.publish( result_message, message.reply_to, mandatory=False, ) except Exception: log.exception("Failed to send reply %r", result_message) await message.reject(requeue=False) return if message_type == RPCMessageTypes.error.value: await message.ack() return await message.ack()
[docs] def serialize(self, data: Any) -> bytes: """ Serialize data to the bytes. Uses `pickle` by default. You should overlap this method when you want to change serializer :param data: Data which will be serialized :returns: bytes """ return super().serialize(data)
[docs] def deserialize(self, data: Any) -> bytes: """ Deserialize data from bytes. Uses `pickle` by default. You should overlap this method when you want to change serializer :param data: Data which will be deserialized :returns: :class:`Any` """ return super().deserialize(data)
[docs] def serialize_exception(self, exception: Exception) -> bytes: """ Serialize python exception to bytes :param exception: :class:`Exception` :return: bytes """ return pickle.dumps(exception)
[docs] async def execute(self, func: CallbackType, payload: P) -> R: """ Executes rpc call. Might be overlapped. """ return await func(**payload)
[docs] async def call( self, method_name, kwargs: Optional[Dict[Hashable, Any]] = None, *, expiration: Optional[int] = None, priority: int = 5, delivery_mode: DeliveryMode = DELIVERY_MODE ): """ Call remote method and awaiting result. :param method_name: Name of method :param kwargs: Methos kwargs :param expiration: If not `None` messages which staying in queue longer will be returned and :class:`asyncio.TimeoutError` will be raised. :param priority: Message priority :param delivery_mode: Call message delivery mode :raises asyncio.TimeoutError: when message expired :raises CancelledError: when called :func:`RPC.cancel` :raises RuntimeError: internal error """ future = self.create_future() message = Message( body=self.serialize(kwargs or {}), type=RPCMessageTypes.call.value, timestamp=time.time(), priority=priority, correlation_id=id(future), delivery_mode=delivery_mode, reply_to=self.result_queue.name, headers={"From": self.result_queue.name}, ) if expiration is not None: message.expiration = expiration log.debug("Publishing calls for %s(%r)", method_name, kwargs) await self.channel.default_exchange.publish( message, routing_key=method_name, mandatory=True, ) log.debug("Waiting RPC result for %s(%r)", method_name, kwargs) return await future
[docs] async def register(self, method_name, func: CallbackType, **kwargs): """ Method creates a queue with name which equal of `method_name` argument. Then subscribes this queue. :param method_name: Method name :param func: target function. Function **MUST** accept only keyword arguments. :param kwargs: arguments which will be passed to `queue_declare` :raises RuntimeError: Function already registered in this :class:`RPC` instance or method_name already used. """ arguments = kwargs.pop("arguments", {}) arguments.update({"x-dead-letter-exchange": self.DLX_NAME}) kwargs["arguments"] = arguments queue = await self.channel.declare_queue(method_name, **kwargs) if func in self.consumer_tags: raise RuntimeError("Function already registered") if method_name in self.routes: raise RuntimeError( "Method name already used for %r" % self.routes[method_name], ) self.consumer_tags[func] = await queue.consume( partial(self.on_call_message, method_name), ) self.routes[method_name] = awaitable(func) self.queues[func] = queue
[docs] async def unregister(self, func): """ Cancels subscription to the method-queue. :param func: Function """ if func not in self.consumer_tags: return consumer_tag = self.consumer_tags.pop(func) queue = self.queues.pop(func) await queue.cancel(consumer_tag) self.routes.pop(queue.name)
class JsonRPC(RPC): SERIALIZER = json CONTENT_TYPE = "application/json" def serialize(self, data: Any) -> bytes: return self.SERIALIZER.dumps(data, ensure_ascii=False, default=repr) def serialize_exception(self, exception: Exception) -> bytes: return self.serialize( { "error": { "type": exception.__class__.__name__, "message": repr(exception), "args": exception.args, }, }, )