Better Agent Testing with Distributed Dependency Injection

Author: Michael Mangus | Engineer at Dosu
Testing agents is hard. Dosu depends on many external services: a relational database, an in-memory cache, multiple LLMs, and a growing number of third party integrations like GitHub and Slack. The runtime behavior of our code is contingent on those services, but actually interfacing with all of them throughout our test suite would be slow to write and slow to run. On top of that, we run our tests a lot when developing, and genuine LLM output is expensive and flaky.
When you send Dosu a message, an orchestrator process starts a distributed workflow using asynchronous Celery tasks. As you’d expect, most of those tasks need to use an LLM client, among other dependencies.
As part of our test suite, we want to be able to inject mock dependencies that simulate specific third party responses, then make assertions about how Dosu behaves. For instance, we want to simulate an LLM output that classifies a message as spam, then assert that the agent skips the remaining tasks in its workflow and does not respond.
This lets us separate two distinct concerns: the behavior of the agent when dealing with spam is covered by a single deterministic test, while the actual quality of our spam detection prompting can be evaluated independently over a large sample using statistical methods.
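To make that concrete, here is a minimal, self-contained sketch of the kind of test this unlocks. Everything in it is illustrative: FakeLLM, FakePublisher, and handle_message are stand-ins, not our actual interfaces.

from dataclasses import dataclass, field


@dataclass
class FakeLLM:
    """Stands in for a real LLM client; returns canned completions."""
    responses: list[str]
    calls: list[str] = field(default_factory=list)

    def complete(self, prompt: str) -> str:
        self.calls.append(prompt)
        return self.responses[len(self.calls) - 1]


@dataclass
class FakePublisher:
    """Stands in for something like a SlackPublisher or GitHubPublisher."""
    sent: list[str] = field(default_factory=list)

    def publish(self, text: str) -> None:
        self.sent.append(text)


def handle_message(body: str, llm: FakeLLM, publisher: FakePublisher) -> None:
    """Toy stand-in for the orchestrated agent workflow."""
    if llm.complete(f"Is this spam? {body}") == "spam":
        return  # spam: skip the rest of the workflow
    publisher.publish(llm.complete(f"Draft a reply to: {body}"))


def test_spam_is_ignored() -> None:
    llm = FakeLLM(responses=["spam"])
    publisher = FakePublisher()
    handle_message("Buy cheap followers!", llm=llm, publisher=publisher)
    assert publisher.sent == []   # Dosu did not respond
    assert len(llm.calls) == 1    # the workflow stopped after classification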
Dependency injection also makes it easy to adapt to different contexts. We can send messages with a SlackPublisher on Slack, or a GitHubPublisher on GitHub. By injecting different LLM clients, we can load balance between regions or experiment with alternative models.
Every time the orchestrator starts a task, it needs to include the dependencies as parameters. Since the agent’s workload is distributed, ultimately the dependencies are going to have to cross a network boundary, so we have to make them serializable — but most user-defined types aren’t JSON serializable by default. In this post, we’ll explain how we implemented a solution that covers most of our dependencies with a single code path.
Registering Preserializers
In last week’s post, we introduced the concept of a Preserializer, which is our term for a class that is able to translate between a non-serializable object and a serializable representation. We showed how you can register a Preserializer for a type using Kombu's register_type function, then seamlessly use instances of that type as parameters when calling an asynchronous task.
With these basic building blocks in mind, let’s consider what it’s going to take to build a Preserializer for our dependencies.
Ultimately, serializing an object means representing all the relevant information about how to make that object again from scratch in another execution environment. Most of our dependencies look approximately like this:
from functools import cached_property


class SomeDependency:
    def __init__(self, endpoint: str, credentials: str):
        self.endpoint = endpoint
        self.credentials = credentials

    @cached_property
    def client(self) -> Client:
        return Client(self.endpoint, self.credentials)

    def do_thing_with_client(self) -> None:
        self.client.do_thing()
Unfortunately, out of the box, this does not work:
@app.task()
def use_dep(dep: SomeDependency) -> tuple[str, str]:
    return dep.endpoint, dep.credentials
----------
>>> d = SomeDependency("example.com", "password")
>>> use_dep.delay(d).get()
EncodeError: Object of type SomeDependency is not JSON serializable
Making dependencies reconstructable
For a type like SomeDependency, the only relevant pieces of state are the endpoint and credentials passed to the constructor. The operations we do with an instance after we’ve constructed it don’t mutate anything important about its state. So, for these relatively simple, just-about-stateless objects, you can make a new equivalent object by knowing two pieces of information:
- what type of object it is
- the __init__ parameters that were used to make it.
Step 1 is very straightforward — we solved it in our last post. The type itself can be specified by two strings: __module__ (the dotted path to import) and __qualname__ (the name within that module).
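For example, with a standard-library class (purely for illustration):

from collections import OrderedDict
from importlib import import_module

# __module__ and __qualname__ identify a type as two plain strings
module, qualname = OrderedDict.__module__, OrderedDict.__qualname__
print(module, qualname)  # collections OrderedDict

# given just those two strings, another process can get the type back.
# (nested classes have a dotted __qualname__, which is why the unpack
# code later in this post walks the path with getattr)
assert getattr(import_module(module), qualname) is OrderedDict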
That brings us to step 2: knowing the parameters that were passed into __init__ when constructing the object we’re serializing, so we can use those parameters again to make our new instance when deserializing.
Of course, you could build a bespoke serialization strategy for every class you need to pass around, specifying all the relevant attributes of that class by hand and packing them up into a dictionary representation. In complicated cases — for example, a more stateful example where method calls mutate the object — that’s likely the best option. However, we’d like to handle all of the simple cases in one fell swoop, since less code means less surface area for maintenance.
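For comparison, here is a rough sketch of what one of those bespoke strategies might look like for a small stateful class. The Counter class is made up, and the pack/unpack methods simply mirror the shape of the Reconstruct preserializer shown later in this post.

from typing import TypedDict


class Counter:
    """A made-up stateful dependency: its state changes after __init__."""

    def __init__(self, start: int = 0) -> None:
        self.count = start

    def increment(self) -> None:
        self.count += 1


class PackedCounter(TypedDict):
    count: int


class PreserializeCounter:
    @classmethod
    def pack(cls, obj: Counter) -> PackedCounter:
        # spell out every relevant attribute by hand
        return {"count": obj.count}

    @classmethod
    def unpack(cls, data: PackedCounter) -> Counter:
        return Counter(start=data["count"])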
Our solution is what we refer to as a reconstructable type. When a class definition is decorated with @reconstructable, we keep a global registry that maps every instance of the class to the __init__ parameters that were used to construct it. While this does use some memory, we store the registry as a WeakKeyDictionary so it doesn’t interfere with garbage collection. We tend to make the instances of our dependencies once in a composition root, then use the same instance many times over, so in practice we don’t have very many distinct reconstructable instances at a time and the memory overhead is low.
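If you haven't used a WeakKeyDictionary before, here is a quick illustration of the behavior we rely on: entries vanish on their own once the keyed instance is no longer referenced anywhere else.

import gc
from weakref import WeakKeyDictionary


class Thing:
    pass


registry: WeakKeyDictionary[Thing, str] = WeakKeyDictionary()
t = Thing()
registry[t] = "init params would go here"
assert len(registry) == 1

del t         # drop the only strong reference
gc.collect()  # not strictly needed on CPython, but makes the point explicit
assert len(registry) == 0  # the entry cleaned itself up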
By relying on the reconstructable instance registry, we can implement a single Preserializer that’s able to handle many different classes without having to care about the details of their individual __init__ signatures or other attributes. When we want to serialize an instance, we can look up the appropriate __init__ parameters in the registry and send them alongside the type itself, very much like we did with Pydantic models.
Like any good magic spell, there is a fair bit of janky-looking stuff going on in here and your mileage may vary, but we’ve been running this code error-free in production for several months now:
from copy import deepcopy
from typing import Any, TypeVar
from weakref import WeakKeyDictionary

T = TypeVar("T")

constructor_registry: WeakKeyDictionary[
    object, tuple[tuple[Any, ...], dict[str, Any]]
] = WeakKeyDictionary()

reconstructable_types: set[type] = set()


def reconstructable(type_: type[T]) -> type[T]:
    """
    Class decorator for types that can be reconstructed from
    their init params. Tracks the init params for every instance
    of the type in the weak-ref'ed `constructor_registry`
    dictionary so that you can re-construct those instances
    later. Automatically registers the type with the `Reconstruct`
    preserializer so that it can be sent as a parameter to a
    Celery task. If the instance gets `copy` or `deepcopy`ed,
    the new copy is also added to the constructor registry.
    """
    original_init = type_.__init__

    def new_init(
        instance: object, *args: Any, **kwargs: Any
    ) -> None:
        original_init(instance, *args, **kwargs)
        constructor_registry[instance] = (args, kwargs)

    type_.__init__ = new_init  # type: ignore[assignment,method-assign]

    # when an object gets `copy`ed or `deepcopy`ed, we need to
    # record a new entry in the constructor registry for the copy.
    # if the object already has a custom `__copy__` or
    # `__deepcopy__` implementation, we'll wrap that.
    original_copy = getattr(type_, "__copy__", None)

    def new_copy(instance: T) -> T:
        args, kwargs = constructor_registry[instance]
        new_instance = (
            original_copy(instance)
            if original_copy
            else instance.__class__(*args, **kwargs)
        )
        constructor_registry[new_instance] = (args, kwargs)
        return new_instance

    # mypy allows this, but not `type_.__copy__ = ...`
    setattr(type_, "__copy__", new_copy)

    original_deepcopy = getattr(type_, "__deepcopy__", None)

    def new_deepcopy(instance: T, memo: dict[Any, Any]) -> T:
        # just to be extra safe, we also deepcopy the args
        # and kwargs
        args, kwargs = constructor_registry[instance]
        args_copy = deepcopy(args)
        kwargs_copy = deepcopy(kwargs)
        new_instance = (
            original_deepcopy(instance, memo)
            if original_deepcopy
            else instance.__class__(*args_copy, **kwargs_copy)
        )
        constructor_registry[new_instance] = args_copy, kwargs_copy
        return new_instance

    setattr(type_, "__deepcopy__", new_deepcopy)

    # register the type as reconstructable so we can use the
    # `Reconstruct` preserializer (defined below)
    reconstructable_types.add(type_)
    register_preserializer(Reconstruct)(type_)

    return type_
Let’s look past the arcana of managing copying; the real meat of this code comes very early on. Whenever we decorate a class with @reconstructable, we put a wrapper around its __init__ method. Our wrapper puts a new entry in the global constructor_registry, where the key is a weakref to the instance that was just made and the value is the *args and **kwargs that were passed. We can use that registry to build this Preserializer:
from importlib import import_module
from types import ModuleType
from typing import Any, Literal, TypedDict


class PackedConstructor(TypedDict):
    module: str
    qualname: str
    init_args: tuple[Any, ...]
    init_kwargs: dict[str, Any]


class Reconstruct:
    @classmethod
    def compatible_with(cls, type_: type) -> Literal[True]:
        if type_ not in reconstructable_types:
            raise TypeError(
                "Use the @reconstructable decorator to register "
                "your type."
            )
        return True

    @classmethod
    def pack(cls, obj: object) -> PackedConstructor:
        try:
            init_args, init_kwargs = constructor_registry[obj]
        except KeyError as e:
            raise RuntimeError(
                f"{obj} missing from reconstructable registry"
            ) from e
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "init_args": init_args,
            "init_kwargs": init_kwargs,
        }

    @classmethod
    def unpack(cls, data: PackedConstructor) -> object:
        m = import_module(data["module"])
        o: type | ModuleType = m
        for a in data["qualname"].split("."):
            o = getattr(o, a)
        if not isinstance(o, type):
            raise TypeError(
                f"{data['module']}.{data['qualname']} "
                "is not a constructable type"
            )
        return o(*data["init_args"], **data["init_kwargs"])
And voilà! This now works:
@reconstructable
class SomeDependency:
    def __init__(self, endpoint: str, credentials: str):
        self.endpoint = endpoint
        self.credentials = credentials

    @cached_property
    def client(self) -> Client:
        return Client(self.endpoint, self.credentials)

    def do_thing_with_client(self) -> None:
        self.client.do_thing()


@app.task()
def use_dep(dep: SomeDependency) -> tuple[str, str]:
    return dep.endpoint, dep.credentials
----------
>>> d = SomeDependency("example.com", "password")
>>> use_dep.delay(d).get()
["example.com", "password"]  # tuples automatically become lists
Caveats
The keen reader probably realizes that there is a small gotcha: all the parameters of our dependencies also need to be serializable, and so on all the way down. Fortunately, Kombu’s serialization strategies work recursively without any special effort. For instance, our LLM class depends on an injected Cache, so we just made both of those types @reconstructable.
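Sketched out, that looks something like this (the constructor signatures here are simplified for illustration):

@reconstructable
class Cache:
    def __init__(self, url: str) -> None:
        self.url = url


@reconstructable
class LLM:
    def __init__(self, model: str, cache: Cache) -> None:
        self.model = model
        self.cache = cache


llm = LLM(model="some-model", cache=Cache(url="redis://cache:6379"))
# packing `llm` sends its init kwargs, including the Cache instance,
# which Kombu serializes recursively via the same Reconstruct preserializer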
There is another, more subtle limitation as well. In Celery, serializing the task parameters to send to the broker is only the first layer. When a worker node consumes your task off the broker, it deserializes the parameters from JSON, but then uses pickle under the hood to actually send those parameters to a local child process for execution (at least in the default prefork mode).
So, in addition to having a JSON serialization layer, we also need to make sure that everything we unpack on the worker is pickleable. In practice, this constraint hasn’t been that big of a deal for us. So far, the only issue we’ve seen is that we need to defer opening any upstream connections that hold a thread lock until after __init__ in reconstructable types (hence the cached_property for the client in our example class).
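To illustrate the failure mode (the class names here are made up): an object holding a thread lock, as many client and connection objects do internally, cannot be pickled, so the lazy cached_property keeps the dependency picklable until the client is actually needed.

import pickle
import threading
from functools import cached_property


class EagerDep:
    def __init__(self) -> None:
        self.conn = threading.Lock()  # stands in for a real connection's lock


class LazyDep:
    @cached_property
    def conn(self) -> threading.Lock:
        return threading.Lock()  # only created on first use


pickle.dumps(LazyDep())   # fine: the lock doesn't exist yet
pickle.dumps(EagerDep())  # TypeError: cannot pickle '_thread.lock' object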
Other applications
By making our dependencies serializable and injecting them, we’ve made it easy to write realistic tests that run our normal distributed workflows with minimal differences from actual production codepaths. It’s hard to overstate the value of having fast, trustworthy, deterministic tests for high-level agent behaviors like when Dosu will reply, what we write to our database during the agent workflow, what state gets passed around between tasks, etc. We can quickly catch basic problems in the agent workflow without ever actually sending any requests upstream to LLM providers.
I hope we’ve also shown how the concept of registering Preserializers has broad utility for projects using Celery. In addition to the strategies we’ve shared in these posts, here are some other examples we’re using in production to make serialization easy:
- Stringify, which can be registered on any object that can be trivially reconstructed from its str representation (e.g. Pydantic’s Url type, which interestingly is not a subtype of BaseModel)
- TypePath and UnionTypeArgs, which we register for Union types and just plain old types themselves (we use those for runtime guardrails on the return types of generic Celery tasks)
- CeleryResultTuple, which helps with tasks that return other tasks' GroupResult and AsyncResult instances
In each case, the basic idea is the same: represent the type itself using __module__ and __qualname__, then add whatever extra information you need to construct your object.
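For a flavor of what these look like, here is a rough sketch of a Stringify-style preserializer. Our production version differs in the details, but it follows the same pack/unpack shape as Reconstruct above.

from importlib import import_module
from types import ModuleType
from typing import TypedDict


class PackedString(TypedDict):
    module: str
    qualname: str
    value: str


class Stringify:
    @classmethod
    def pack(cls, obj: object) -> PackedString:
        return {
            "module": obj.__class__.__module__,
            "qualname": obj.__class__.__qualname__,
            "value": str(obj),
        }

    @classmethod
    def unpack(cls, data: PackedString) -> object:
        # same module/qualname walk as Reconstruct.unpack above
        o: type | ModuleType = import_module(data["module"])
        for attr in data["qualname"].split("."):
            o = getattr(o, attr)
        # e.g. pydantic's Url can be rebuilt from its string form
        return o(data["value"])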
Got a great idea for how to serialize a tricky type? Tell us about it on Discord.
Tired of writing, maintaining, and sharing documentation? Let Dosu do it for you. Try it out for free at https://app.dosu.dev/.