
Better Agent Testing with Distributed Dependency Injection

Author: Michael Mangus | Engineer at Dosu

Testing agents is hard. Dosu depends on many external services: a relational database, an in-memory cache, multiple LLMs, and a growing number of third-party integrations like GitHub and Slack. The runtime behavior of our code is contingent on those services, but actually interfacing with all of them throughout our test suite would be slow to write and slow to run. On top of that, we run our tests a lot during development, and genuine LLM output is expensive and flaky.

When you send Dosu a message, an orchestrator process starts a distributed workflow using asynchronous Celery tasks. As you’d expect, most of those tasks need to use an LLM client, among other dependencies.

As part of our test suite, we want to be able to inject mock dependencies that simulate specific third-party responses, then make assertions about how Dosu behaves. For instance, we want to simulate an LLM output that classifies a message as spam, then assert that the agent skips the remaining tasks in its workflow and does not respond.

This lets us separate two distinct concerns: the behavior of the agent when dealing with spam is covered by a single deterministic test, while the actual quality of our spam detection prompting can be evaluated independently over a large sample using statistical methods.
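
As a rough illustration, here is the shape of such a test with a hand-rolled fake LLM client. The names here (FakeLLMClient, respond_to_message) are hypothetical stand-ins, not Dosu's actual API:

from dataclasses import dataclass, field


@dataclass
class FakeLLMClient:
    """Returns canned outputs instead of calling a real provider."""
    responses: list[str]
    calls: list[str] = field(default_factory=list)

    def complete(self, prompt: str) -> str:
        self.calls.append(prompt)
        return self.responses.pop(0)


def respond_to_message(message: str, llm: FakeLLMClient) -> str | None:
    # Stand-in for the real workflow: classify first, bail out on spam.
    if llm.complete(f"Is this spam? {message}") == "spam":
        return None
    return llm.complete(f"Draft a reply to: {message}")


def test_spam_is_skipped() -> None:
    llm = FakeLLMClient(responses=["spam"])
    assert respond_to_message("buy cheap pills now", llm) is None
    assert len(llm.calls) == 1  # the reply prompt never ran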

Dependency injection also makes it easy to adapt to different contexts. We can send messages with a SlackPublisher on Slack, or a GitHubPublisher on GitHub. By injecting different LLM clients, we can load balance between regions or experiment with alternative models.
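
Conceptually, the publishers share a small interface and the workflow only depends on that interface. A minimal sketch (the real SlackPublisher and GitHubPublisher are of course more involved than this):

from typing import Protocol


class Publisher(Protocol):
    def publish(self, message: str) -> None: ...


class SlackPublisher:
    def __init__(self, channel: str):
        self.channel = channel

    def publish(self, message: str) -> None:
        ...  # post the message to self.channel via the Slack API


class GitHubPublisher:
    def __init__(self, repo: str, issue: int):
        self.repo = repo
        self.issue = issue

    def publish(self, message: str) -> None:
        ...  # comment on the issue via the GitHub API


def send_reply(publisher: Publisher, reply: str) -> None:
    # The workflow never cares which concrete publisher it was given.
    publisher.publish(reply)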

Every time the orchestrator starts a task, it needs to include the dependencies as parameters. Since the agent’s workload is distributed, ultimately the dependencies are going to have to cross a network boundary, so we have to make them serializable — but most user-defined types aren’t JSON serializable by default. In this post, we’ll explain how we implemented a solution that covers most of our dependencies with a single code path.

Registering Preserializers

In last week’s post, we introduced the concept of a Preserializer, which is our term for a class that is able to translate between a non-serializable object and a serializable representation. We showed how you can register a Preserializer for a type using Kombu's register_type function, then seamlessly use instances of that type as parameters when calling an asynchronous task. 
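
As a quick refresher, a Preserializer in this scheme is just a class with a pack/unpack pair plus a compatibility check, matching the shape of the Reconstruct class we build later in this post. A toy example for datetime, assuming the register_preserializer helper from that post:

from datetime import datetime


class IsoDatetime:
    """Toy Preserializer: datetimes round-trip through their ISO string."""

    @classmethod
    def compatible_with(cls, type_: type) -> bool:
        return issubclass(type_, datetime)

    @classmethod
    def pack(cls, obj: datetime) -> str:
        return obj.isoformat()

    @classmethod
    def unpack(cls, data: str) -> datetime:
        return datetime.fromisoformat(data)


# register_preserializer comes from the previous post; hypothetical usage:
# register_preserializer(IsoDatetime)(datetime)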

With these basic building blocks in mind, let’s consider what it’s going to take to build a Preserializer for our dependencies.

Ultimately, serializing an object means representing all the relevant information about how to make that object again from scratch in another execution environment. Most of our dependencies look approximately like this:

from functools import cached_property


class SomeDependency:
    def __init__(self, endpoint: str, credentials: str):
        self.endpoint = endpoint
        self.credentials = credentials

    @cached_property
    def client(self) -> Client:
        # Client stands in for some upstream SDK; built lazily on first use
        return Client(self.endpoint, self.credentials)

    def do_thing_with_client(self) -> None:
        self.client.do_thing()

Unfortunately, out of the box, this does not work:

@app.task()
def use_dep(dep: SomeDependency) -> tuple[str, str]:
    return dep.endpoint, dep.credentials

----------
>>> d = SomeDependency("example.com", "password")
>>> use_dep.delay(d).get()
EncodeError: Object of type SomeDependency is not JSON serializable

Making dependencies reconstructable

For a type like SomeDependency, the only relevant pieces of state are the endpoint and credentials passed to the constructor. The operations we do with an instance after we’ve constructed it don’t mutate anything important about its state. So, for these relatively simple, just-about-stateless objects, you can make a new equivalent object by knowing two pieces of information:

  1. what type of object it is
  2. the __init__ parameters that were used to make it.

Step 1 is very straightforward — we solved it in our last post. The type itself can be specified by two strings: __module__ (the dotted path to import) and __qualname__ (the name within that module).
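
Any importable class already carries both strings, and together they are enough to get the class object back:

>>> from collections import OrderedDict
>>> OrderedDict.__module__, OrderedDict.__qualname__
('collections', 'OrderedDict')
>>> import importlib
>>> getattr(importlib.import_module('collections'), 'OrderedDict')
<class 'collections.OrderedDict'>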

That brings us to step 2: knowing the parameters that were passed into __init__ when constructing the object we’re serializing, so we can use those parameters again to make our new instance when deserializing.

Of course, you could build a bespoke serialization strategy for every class you need to pass around, specifying all the relevant attributes of that class by hand and packing them up into a dictionary representation. In complicated cases — for example, a more stateful example where method calls mutate the object — that’s likely the best option. However, we’d like to handle all of the simple cases in one fell swoop, since less code means less surface area for maintenance.
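
For contrast, here is what a bespoke strategy might look like for a hypothetical stateful RateLimiter, with every relevant attribute enumerated by hand:

class RateLimiter:
    """Hypothetical stateful dependency: `used` mutates over time."""

    def __init__(self, limit: int, window_seconds: int):
        self.limit = limit
        self.window_seconds = window_seconds
        self.used = 0


class RateLimiterPreserializer:
    @classmethod
    def pack(cls, obj: RateLimiter) -> dict[str, int]:
        return {
            "limit": obj.limit,
            "window_seconds": obj.window_seconds,
            "used": obj.used,
        }

    @classmethod
    def unpack(cls, data: dict[str, int]) -> RateLimiter:
        limiter = RateLimiter(data["limit"], data["window_seconds"])
        limiter.used = data["used"]  # restore the mutated state by hand
        return limiter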

Our solution is what we refer to as a reconstructable type. When a class definition is decorated with @reconstructable, we keep a global registry that maps every instance of the class to the __init__ parameters that were used to construct it. While this does use some memory, we store the registry as a WeakKeyDictionary so it doesn’t interfere with garbage collection. We tend to make the instances of our dependencies once in a composition root, then use the same instance many times over, so in practice we don’t have very many distinct reconstructable instances at a time and the memory overhead is low.

By relying on the reconstructable instance registry, we can implement a single Preserializer that’s able to handle many different classes without having to care about the details of their individual __init__ signatures or other attributes. When we want to serialize an instance, we can look up the appropriate __init__ parameters in the registry and send them alongside the type itself, very much like we did with Pydantic models.

Like any good magic spell, there is a fair bit of janky-looking stuff going on in here and your mileage may vary, but we’ve been running this code error-free in production for several months now:

from copy import deepcopy
from typing import Any, TypeVar
from weakref import WeakKeyDictionary

# `register_preserializer` is the registration helper from last week's post;
# `Reconstruct` is the preserializer defined later in this one.

T = TypeVar("T")

constructor_registry: WeakKeyDictionary[
    object, tuple[tuple[Any, ...], dict[str, Any]]
] = WeakKeyDictionary()

reconstructable_types: set[type] = set()


def reconstructable(type_: type[T]) -> type[T]:
    """
    Class decorator for types that can be reconstructed from 
     their init params. Tracks the init params for every instance
     of the type in the weak-ref'ed `constructor_registry` 
     dictionary so that you can re-construct those instances 
     later. Automatically registers the type with the `Reconstruct`
     preserializer so that it can be sent as a parameter to a 
     Celery task. If the instance gets `copy` or `deepcopy`ed, 
     the new copy is also added to the constructor registry.
    """
    original_init = type_.__init__
    def new_init(
        instance: object, *args: Any, **kwargs: Any
    ) -> None:
        original_init(instance, *args, **kwargs)
        constructor_registry[instance] = (args, kwargs)
        
    type_.__init__ = new_init  # type: ignore[assignment,method-assign]
    
    # when an object gets `copy`ed or `deepcopy`ed, we need to 
    #  record a new entry in the constructor registry for the copy. 
    #  if the object already has a custom `__copy__` or 
    #  `__deepcopy__` implementation, we'll wrap that.
    original_copy = getattr(type_, "__copy__", None)
    def new_copy(instance: T) -> T:
        args, kwargs = constructor_registry[instance]
        new_instance = (
            original_copy(instance)
            if original_copy
            else instance.__class__(*args, **kwargs)
        )
        constructor_registry[new_instance] = (args, kwargs)
        return new_instance
    # mypy allows this, but not `type_.__copy__ = ...`
    setattr(type_, "__copy__", new_copy)
    
    original_deepcopy = getattr(type_, "__deepcopy__", None)
    def new_deepcopy(instance: T, memo: dict[Any, Any]) -> T:
        # just to be extra safe, we also deepcopy the args 
        #  and kwargs
        args, kwargs = constructor_registry[instance]
        args_copy = deepcopy(args)
        kwargs_copy = deepcopy(kwargs)
        new_instance = (
            original_deepcopy(instance, memo)
            if original_deepcopy
            else instance.__class__(*args_copy, **kwargs_copy)
        )
        constructor_registry[new_instance] = args_copy, kwargs_copy
        return new_instance
    setattr(type_, "__deepcopy__", new_deepcopy)
    # register the type as reconstructable so we can use the 
    #  `Reconstruct` preserializer
    reconstructable_types.add(type_)
    register_preserializer(Reconstruct)(type_)
    return type_

Let’s look past the arcana of managing copying; the real meat of this code comes very early on. Whenever we decorate a class with @reconstructable, we put a wrapper around its __init__ method. Our wrapper adds a new entry to the global constructor_registry, where the key is a weak reference to the instance that was just made and the value is the *args and **kwargs that were passed. We can use that registry to build this Preserializer:

from importlib import import_module
from types import ModuleType
from typing import Any, Literal, TypedDict


class PackedConstructor(TypedDict):
    module: str
    qualname: str
    init_args: tuple[Any, ...]
    init_kwargs: dict[str, Any]
    
class Reconstruct:
    @classmethod
    def compatible_with(cls, type_: type) -> Literal[True]:
        if type_ not in reconstructable_types:
            raise TypeError(
                "Use the @reconstructable decorator to register "
                "your type."
            )
        return True
        
    @classmethod
    def pack(cls, obj: object) -> PackedConstructor:
        try:
            init_args, init_kwargs = constructor_registry[obj]
        except KeyError as e:
            raise RuntimeError(
                f"{obj} missing from reconstructable registry"
            ) from e
            
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "init_args": init_args,
            "init_kwargs": init_kwargs,
        }
        
    @classmethod
    def unpack(cls, data: PackedConstructor) -> object:
        m = import_module(data["module"])
        o: type | ModuleType = m
        for a in data["qualname"].split("."):
            o = getattr(o, a)
            
        if not isinstance(o, type):
            raise TypeError(
                f"{data['module']}.{data['qualname']} "
                "is not a constructable type"
            )
            
        return o(*data["init_args"], **data["init_kwargs"])

And voilà! This now works:

@reconstructable
class SomeDependency:
    def __init__(self, endpoint: str, credentials: str):
        self.endpoint = endpoint
        self.credentials = credentials
    
    @cached_property
    def client(self) -> Client:
        return Client(self.endpoint, self.credentials)
   
    def do_thing_with_client(self) -> None:
        self.client.do_thing()
  
@app.task()
def use_dep(dep: SomeDependency) -> tuple[str, str]:
    return dep.endpoint, dep.credentials

----------
>>> d = SomeDependency("example.com", "password")
>>> use_dep.delay(d).get()
["example.com", "password"]  # tuples automatically become lists

Caveats

The keen reader probably realizes that there is a small gotcha: all the parameters of our dependencies also need to be serializable, and so on all the way down. Fortunately, Kombu’s serialization strategies work recursively without any special effort. For instance, our LLM class depends on an injected Cache, so we just made both of those types @reconstructable.
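
A simplified sketch of what that nesting looks like (the real Cache and LLM classes have more going on than this):

@reconstructable
class Cache:
    def __init__(self, url: str):
        self.url = url


@reconstructable
class LLM:
    def __init__(self, model: str, cache: Cache):
        self.model = model
        self.cache = cache  # itself reconstructable, so it nests cleanly


llm = LLM("some-model", Cache("redis://localhost:6379"))
# Reconstruct.pack(llm) stores the Cache instance inside init_args;
# when Kombu encodes that payload, it applies the registered
# preserializer for Cache recursively.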

There is another, more subtle limitation as well. In Celery, serializing the task parameters to send to the broker is only the first layer. When a worker node consumes your task off the broker, it deserializes the parameters from JSON, but then uses pickle under the hood to actually send those parameters to a local child process for execution (at least in the default prefork mode).

So, in addition to having a JSON serialization layer, we also need to make sure that everything we unpack on the worker is pickleable. In practice, this constraint hasn’t been that big of a deal for us. So far, the only issue we’ve seen is that we need to defer opening any upstream connections that hold a thread lock until after __init__ in reconstructable types (hence the cached_property for the client in our example class).
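
In other words, this is the pattern we avoid versus the one we use (Client here stands in for any SDK object that eagerly opens a connection or grabs a lock):

@reconstructable
class EagerDependency:
    def __init__(self, endpoint: str):
        # The connection (and any thread lock it holds) lives on the
        # instance immediately, which can break pickling on the worker.
        self.client = Client(endpoint)


@reconstructable
class LazyDependency:
    def __init__(self, endpoint: str):
        self.endpoint = endpoint

    @cached_property
    def client(self) -> Client:
        # Only constructed once we're inside the process that
        # actually runs the task.
        return Client(self.endpoint)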

Other applications

By making our dependencies serializable and injecting them, we’ve made it easy to write realistic tests that run our normal distributed workflows with minimal differences from actual production codepaths. It’s hard to overstate the value of having fast, trustworthy, deterministic tests for high-level agent behaviors like when Dosu will reply, what we write to our database during the agent workflow, what state gets passed around between tasks, etc. We can quickly catch basic problems in the agent workflow without ever actually sending any requests upstream to LLM providers.

I hope we’ve also shown how the concept of registering Preserializers has broad utility for projects using Celery. In addition to the strategies we’ve shared in these posts, here are some other examples we’re using in production to make serialization easy: 

  • Stringify, which can be registered on any object that can be trivially reconstructed from its str representation (e.g. Pydantic’s Url type, which interestingly is not a subtype of BaseModel)
  • TypePath and UnionTypeArgs, which we register for Union types and just plain old types themselves (we use those for runtime guardrails on the return types of generic Celery tasks)
  • CeleryResultTuple, which helps with tasks that return other tasks' GroupResult and AsyncResult instances

In each case, the basic idea is the same: represent the type itself using __module__ and __qualname__, then add whatever extra information you need to construct your object.
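
As a hedged guess at what the first of those looks like, a Stringify-style preserializer only needs the type plus the string value; our actual implementation may differ in detail:

from importlib import import_module
from typing import Any, TypedDict


class PackedString(TypedDict):
    module: str
    qualname: str
    value: str


class Stringify:
    @classmethod
    def pack(cls, obj: object) -> PackedString:
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "value": str(obj),
        }

    @classmethod
    def unpack(cls, data: PackedString) -> object:
        o: Any = import_module(data["module"])
        for attr in data["qualname"].split("."):
            o = getattr(o, attr)
        return o(data["value"])  # reconstruct from the str representation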

Got a great idea for how to serialize a tricky type? Tell us about it on Discord.

Tired of writing, maintaining, and sharing documentation? Let Dosu do it for you. Try it out for free at https://app.dosu.dev/.