Celery Preserializers: A low-friction path to Pydantic support

Author: Michael Mangus | Engineer at Dosu

At Dosu, we use Celery for asynchronous tasks and Pydantic for data validation. If you’re like us, you were probably excited to see “Pydantic support” in the patch notes for Celery 5.5.0, then disappointed when you read the disclaimer that “you still have [to] serialize arguments yourself when invoking a task.”

If you want to move fast without breaking things, the lowest-friction way to do something should also be the correct way to do it. If it’s hard to make a mistake, it’s easy to ship with confidence. Celery’s new Pydantic support has many sources of friction that invite you to make mistakes:

  1. You need to remember to set pydantic=True in the task decorator.
  2. You need to manually call model_dump to convert instances to their dictionary representation when passing them into delay or apply_async.
  3. Your tasks return dictionaries, not model instances, so you need to manually deserialize the output.
  4. Because of issues 2 and 3, your type annotations for tasks with pydantic=True are misleading. From the caller’s perspective, the arguments and return types annotated as Pydantic models are actually of type dict.

In this post, we’ll explain how we eliminated all of these sources of friction with an alternative implementation which even works in older versions of Celery. In our codebase, if you want to write a Celery task that uses Pydantic models, you just do it. There are no other steps, and it works like you’d expect it to work.

I am not sure if this technique is valid for every possible configuration of Celery — which might be why it wasn’t implemented this way in the library — but for our fairly typical use case it’s been working hassle-free in production for the past few months.

Celery depends on serialization

A Celery deployment consists of a message broker which tracks work that needs to be done, a worker pool which actually does it, and a result backend which stores the output. You wrap your Python function f with a @task decorator to make it runnable on the worker pool, then use f.delay(...) to put a message on the broker asking a worker to run that function with the given parameters.
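If you haven’t seen it in a while, a minimal setup looks something like this (the broker and backend URLs here are placeholders):

from celery import Celery

# placeholder broker/backend URLs - substitute your own
app = Celery(
    "demo",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)

@app.task()
def add(x: int, y: int) -> int:
    return x + y

----------
>>> add.delay(2, 3).get()  # a worker runs add(2, 3)
5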

One limitation of this architecture is that you need to serialize the task parameters to put them on the message broker. By default, all the parameters to a Celery task must be serializable as JSON. However, the built-in Python JSON encoder only supports a small set of types, and Pydantic models aren't among them.

Kombu is Celery’s homegrown messaging and serialization library. If you play around with Celery a bit, you’ll notice that serialization isn’t exactly limited to the types json.dumps can handle. For example, a datetime is not serializable by default in Python:

>>> import json
>>> from datetime import datetime
>>> json.dumps(datetime.now())
TypeError: Object of type datetime is not JSON serializable

Even so, this Celery task works just fine out of the box:

@app.task()
def day(dt: datetime) -> int:
    return dt.day

----------    
>>> day.delay(datetime.now()).get()
15

How can that be?

Kombu’s register_type functionality

As it turns out, Kombu keeps a registry of custom encoders and decoders for certain types, which it pre-populates with a few common types. These encoders and decoders are simple ways of translating a non-serializable type into a serializable one. For example, a datetime can be encoded as str using datetime.isoformat, then decoded from that str using datetime.fromisoformat.

By calling the function kombu.utils.json.register_type, you can define a custom encoder and decoder for a type. Once your type is registered with Kombu, it works seamlessly with Celery tasks.
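For instance, the built-in datetime handling could be written roughly like this (a sketch for illustration, not Kombu’s actual source):

from datetime import datetime
from kombu.utils.json import register_type

register_type(
    datetime,
    "datetime",
    encoder=datetime.isoformat,
    decoder=datetime.fromisoformat,
)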

That’s exactly how we wanted Pydantic support to work. All we need to do is register an encoder to pack up a BaseModel into a serializable format and a decoder to unpack it back into the appropriate type on the worker.

Since Pydantic is built for serialization, model_dump does most of the work for us. However, to reconstruct an instance, you also need to know which model class to instantiate. The class can be specified by two strings: __module__ (the dotted path to import) and __qualname__ (the name within that module). You’ll recognize these as the components of the dotted paths you often see in the string representations of objects. Here, the __module__ is "x.y.z" and the __qualname__ is "Example":

>>> from x.y.z import Example
>>> str(Example)
"<class 'x.y.z.Example'>"

We can use this information to import the type when deserializing:

from importlib import import_module
from types import ModuleType
from typing import Any

def load_from_path(module: str, qualname: str) -> Any:
    """
    Given a dotted path to a module and the qualified name of a 
     member of the module, import the module and return the 
     named member.
    """
    m = import_module(module)
    o: type | ModuleType = m
    # this handles e.g. a class nested in a class
    for a in qualname.split("."):
        o = getattr(o, a)
    return o
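For example, fetching a class from the standard library:

>>> load_from_path("collections", "OrderedDict")
<class 'collections.OrderedDict'>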

Now we’ve got the pieces we need. Let’s put it all together.

Implementing a Preserializer for BaseModel

At Dosu, we encapsulate the Kombu encoding and decoding operations in a type we call a Preserializer. It is so-named because it doesn’t actually produce any JSON itself, just JSON-serializable objects. The interface looks like this:

from typing import Any, Literal, Protocol

class Preserializer(Protocol):
    """
    A Preserializer can be used to `pack` a non-serializable  
     object into a serializable one, then `unpack` it again.
    """

    @classmethod
    def compatible_with(cls, type_: type) -> Literal[True]:
        """
        If the given type is compatible with this strategy, return 
         `True`. If not, raise an exception explaining why not.
        """
        
    @classmethod
    def pack(cls, obj: Any) -> Any:
        """
        Pack the given object into a JSON-serializable object.
        """

    @classmethod
    def unpack(cls, data: Any) -> Any:
        """
        Unpack the serializable object back into an instance of
         its original type.
        """ 

Here’s a Preserializer that handles Pydantic models:

from typing import Any, Literal, TypedDict

from pydantic import BaseModel

class PackedModel(TypedDict):
    module: str
    qualname: str
    dump: dict[str, Any]


class PydanticModelDump:

    @classmethod
    def compatible_with(cls, type_: type) -> Literal[True]:
        if not issubclass(type_, BaseModel):
            raise TypeError(
                "PydanticModelDump requires a type that inherits "
                "from BaseModel"
            )
        return True

    @classmethod
    def pack(cls, obj: BaseModel) -> PackedModel:
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "dump": obj.model_dump(),
        }

    @classmethod
    def unpack(cls, data: PackedModel) -> BaseModel:
        t = load_from_path(data["module"], data["qualname"])
        if not (isinstance(t, type) and issubclass(t, BaseModel)):
            raise TypeError(
                f"Cannot unpack {t}: not a Pydantic model"
            )
        # NB: our actual implementation has some special handling
        #  for custom RootModel subclasses that we use, but I've
        #  omitted it here for simplicity - this only handles
        #  basic models
        return t(**data["dump"])
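To illustrate the round trip, suppose we have a hypothetical User model defined in an importable module named models:

>>> from models import User
>>> packed = PydanticModelDump.pack(User(name="Ada"))
>>> packed
{'module': 'models', 'qualname': 'User', 'dump': {'name': 'Ada'}}
>>> PydanticModelDump.unpack(packed)
User(name='Ada')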

In order to assign a Preserializer to a type in the Kombu registry, we have a decorator factory, @register_preserializer(...):

from typing import TypeVar

from kombu.utils.json import register_type

T = TypeVar("T")

class register_preserializer:
    """
    Decorator factory that registers a Preserializer for the
     decorated type in the Kombu JSON type registry.
    """
    def __init__(self, preserializer: Preserializer):
        self.preserializer = preserializer

    def __call__(self, type_: type[T]) -> type[T]:
        if (
            "<locals>" in type_.__qualname__
            or "__main__" in type_.__module__
        ):
            raise TypeError(
                "You cannot register preserializers on objects "
                "that aren't directly accessible at import time."
            )

        try:
            self.preserializer.compatible_with(type_)
        except Exception as e:
            raise TypeError(
                f"{type_} is not compatible with "
                f"{self.preserializer}: {e}"
            ) from e

        register_type(
            type_,
            f"{type_.__module__}.{type_.__qualname__}",
            encoder=self.preserializer.pack,
            decoder=self.preserializer.unpack,
        )
        return type_

Finally, we register PydanticModelDump to handle every instance of BaseModel like so:

from pydantic import BaseModel

register_preserializer(PydanticModelDump)(BaseModel)

In order for this to work, you need to make sure that the type gets registered both when producing tasks and when consuming them on the worker. We accomplish that by executing the register_preserializer call at the top level of the module where we define our Celery entrypoint, app = Celery(...). That module is always imported by workers, and also everywhere we use the @app.task decorator, so the registration code runs when we need it to.
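Concretely, that module looks something like this (the project layout and module names here are hypothetical):

# myproject/celery_app.py - imported by every worker and by
#  every module that uses @app.task
from celery import Celery
from pydantic import BaseModel

from myproject.preserializers import (
    PydanticModelDump,
    register_preserializer,
)

# runs at import time on both producers and workers
register_preserializer(PydanticModelDump)(BaseModel)

app = Celery("myproject")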

Solutions compared

With our implementation, Pydantic models are usable without any special code when defining or calling your task. Let’s compare our solution with the official approach side by side to see how they differ.

In the Celery documentation, they give this example:

from pydantic import BaseModel

class ArgModel(BaseModel):
    value: int

class ReturnModel(BaseModel):
    value: str

@app.task(pydantic=True)  # can't forget this
def x(arg: ArgModel) -> ReturnModel:  # types only right inside x
    return ReturnModel(value=f"example: {arg.value}")

----------
>>> result = x.delay({'value': 1})  # must pass a dict
>>> result.get(timeout=1)
{'value': 'example: 1'}  # get back a dict

On the other hand, this is how our solution looks:

# using the same models as above 

@app.task()  # no special param
def x(arg: ArgModel) -> ReturnModel:  # types right everywhere
    return ReturnModel(value=f"example: {arg.value}")

----------
>>> result = x.delay(ArgModel(value=1))  # pass an instance
>>> result.get(timeout=1)
ReturnModel(value="example: 1")  # get back an instance

There’s nothing special to think about, so there’s nothing you can forget or do wrong. Exactly what we like.

What next?

Obviously you can extend this concept with more serialization strategies for other types. We currently use 6 varieties of Preserializer, and most of those are capable of handling many different concrete types.
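As an illustration (not one of our actual six), here’s a sketch of what a Preserializer for Enum members might look like, reusing load_from_path and register_preserializer from above:

from enum import Enum

class PackedEnum(TypedDict):
    module: str
    qualname: str
    value: Any


class EnumDump:

    @classmethod
    def compatible_with(cls, type_: type) -> Literal[True]:
        if not issubclass(type_, Enum):
            raise TypeError(
                "EnumDump requires a type that inherits from Enum"
            )
        return True

    @classmethod
    def pack(cls, obj: Enum) -> PackedEnum:
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "value": obj.value,
        }

    @classmethod
    def unpack(cls, data: PackedEnum) -> Enum:
        t = load_from_path(data["module"], data["qualname"])
        if not (isinstance(t, type) and issubclass(t, Enum)):
            raise TypeError(f"Cannot unpack {t}: not an Enum")
        return t(data["value"])


register_preserializer(EnumDump)(Enum)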

This is the first installment of a two-part series. Next week, we’ll cover how we built a Preserializer to facilitate better testing with dependency injection.

Tired of writing, maintaining, and sharing documentation? Let Dosu do it for you. Try it out for free at https://app.dosu.dev/. If you’d like to discuss this article, give general feedback or ask questions, feel free to join our Discord.