In Python, I often have the situation where I create a dictionary, and want to ensure that it is complete – it has an entry for every valid key.
Let’s say for my (currently hypothetical) automatic squirrel-deterring water gun system, I have a number of different states the water tank can be in, defined using an enum:
```python
from enum import StrEnum


class TankState(StrEnum):
    FULL = "FULL"
    HALF_FULL = "HALF_FULL"
    NEARLY_EMPTY = "NEARLY_EMPTY"
    EMPTY = "EMPTY"
```
In a separate bit of code, I define an RGB colour for each of these states, using a simple dict.
```python
TANK_STATE_COLORS = {
    TankState.FULL: 0x00FF00,
    TankState.HALF_FULL: 0x28D728,
    TankState.NEARLY_EMPTY: 0xFF9900,
    TankState.EMPTY: 0xFF0000,
}
```
This is deliberately kept separate from my `TankState` code and related definitions, because it relates to a different part of the project: the user interface. The UI concerns shouldn’t be mixed up with the core logic.
This dict is fine, and currently complete. But I’d like to ensure that if I add a new item to `TankState`, I don’t forget to update the `TANK_STATE_COLORS` dict.
As static type checking in Python has become more capable, some people have asked how we can enforce this with static type checks. The short answer is that we can’t (at least at the moment).
But the better question is “how can we (somehow) ensure we don’t forget?” It doesn’t have to be a static type check, as long as it’s very hard to forget and, preferably, runs as early as possible.
Instead of shoehorning everything into static type checks, let’s just make use of the fact that this is Python and we can write any code we want at module level. All we need to do is this:
```python
TANK_STATE_COLORS = {
    # …
}

for val in TankState:
    assert val in TANK_STATE_COLORS, f"TANK_STATE_COLORS is missing an entry for {val}"
```
That’s it, that’s the whole technique. I’d argue that this is a pretty much optimal, Pythonic solution to the problem. No clever type tricks to debug later, just two lines of plain, simple code, and it’s impossible to import your code until you fix the problem (as long as assertions are enabled, i.e. you aren’t running Python with `-O`, which strips `assert` statements), so you get the early checking you want. Plus you get exactly the error message you want, not some obscure compiler output, which is also really important.
It can also be extended if you want to do something fancier (e.g. allow some values of the enum to be missing), and if it does get in your way, you can turn it off temporarily by commenting out a couple of lines.
That’s not quite it
OK, in a project where I’m using this a lot, I did eventually get bored of this small bit of boilerplate. So, as a Pythonic extension of this Pythonic solution, I now do this:
```python
TANK_STATE_COLORS: dict[TankState, int] = {
    TankState.FULL: 0x00FF00,
    TankState.HALF_FULL: 0x28D728,
    TankState.NEARLY_EMPTY: 0xFF9900,
    TankState.EMPTY: 0xFF0000,
}

assert_complete_enumerations_dict(TANK_STATE_COLORS)
```
Specifically, I’m adding:
- a type hint on the constant
- a call to a clever utility function that does just the right amount of Python magic.
This function needs to be “magical” because we want it to produce good error messages, like we had before. This means it needs to get hold of the name of the dict in the calling module, but functions don’t usually have access to that.
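In addition, it needs the type hint, which is also looked up by name. There would be other ways to infer the key type without a type hint, but using the hint has advantages: for example, an empty dict gives you nothing to infer from, and the set of values allowed by a `Literal` key type can’t be recovered from the dict’s keys.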
The specific magic we need is:
- the clever function needs to get hold of the module that called it
- it then looks through the module dictionary to get the name of the object that has been passed in
- then it can find type hints, and do the checking.
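The first two steps can be sketched on their own. This is a simplified, hypothetical helper (`names_bound_to`), not the full code: it reads the caller’s module name from the frame’s globals instead of calling `inspect.getmodule`, and it relies on `sys._getframe`, which is a CPython implementation detail. Note that names are matched by identity (`is`), so a separate dict with equal contents is not mistaken for the one passed in.

```python
import sys
from typing import Any


def names_bound_to(obj: Any) -> tuple[str, list[str]]:
    """Return the caller's module name and every name in the caller's
    namespace that is bound to ``obj`` (matched by identity)."""
    frame = sys._getframe(1)  # the caller's frame (CPython-specific)
    module_name = frame.f_globals.get("__name__", "?")
    # At module level, f_locals is the module's own namespace.
    names = [name for name, value in frame.f_locals.items() if value is obj]
    return module_name, names


TANK_STATE_COLORS = {"FULL": 0x00FF00, "EMPTY": 0xFF0000}
ALIAS = TANK_STATE_COLORS  # same object, so it is found too
COPY = dict(TANK_STATE_COLORS)  # equal contents, different object: not matched

module_name, names = names_bound_to(TANK_STATE_COLORS)
print(module_name, names)
```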
So, because you don’t want to write all that yourself, the code is below. It also supports:
- having a tuple of `Enum` types as the key
- allowing some items to be missing
- using `Literal` as the key.
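Those extra key types mostly change how the set of expected keys is computed. Here is a minimal standalone sketch of that part, mirroring the branches of the function: a tuple of enums gives every combination of members (a Cartesian product), and a `Literal` gives its arguments. The `expected_keys` helper and the `Color`, `Size` and `PRICES` names are hypothetical, for illustration only.

```python
import itertools
import typing
from enum import Enum
from typing import Literal


class Color(Enum):
    RED = 1
    GREEN = 2


class Size(Enum):
    SMALL = 1
    LARGE = 2


def expected_keys(key_type: object) -> list[object]:
    """Every valid key for an Enum, tuple-of-Enums or Literal key type."""
    origin = typing.get_origin(key_type)
    if origin is tuple:
        # tuple[A, B]: every combination of members.
        enum_classes = typing.get_args(key_type)
        return list(itertools.product(*(list(cls) for cls in enum_classes)))
    if origin is Literal:
        # Literal["a", "b"]: the allowed values are the Literal's arguments.
        return list(typing.get_args(key_type))
    if isinstance(key_type, type) and issubclass(key_type, Enum):
        return list(key_type)
    raise TypeError(f"Unsupported key type: {key_type!r}")


PRICES: dict[tuple[Color, Size], float] = {
    (Color.RED, Size.SMALL): 1.0,
    (Color.RED, Size.LARGE): 2.0,
    (Color.GREEN, Size.SMALL): 1.5,
}
allowed_missing = [(Color.GREEN, Size.LARGE)]

missing = [
    key
    for key in expected_keys(tuple[Color, Size])
    if key not in PRICES and key not in allowed_missing
]
assert not missing, f"PRICES needs entries for {missing}"
```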
It’s got a ton of error checking, because once things get magical, you really don’t want to be debugging obscure error messages.
Enjoy!
```python
import inspect
import itertools
import sys
import typing
from collections.abc import Mapping, Sequence
from enum import Enum

from frozendict import frozendict


def assert_complete_enumerations_dict[T](the_dict: Mapping[T, object], *, allowed_missing: Sequence[T] = ()):
    """
    Magically assert that the dict in the calling module has a value for every item in an enumeration.

    The dict object must be bound to a name in the module.

    It must be type hinted, with the key being an Enum subclass, or Literal.

    The key may also be a tuple of Enum subclasses.

    If you expect some values to be missing, pass them in `allowed_missing`.
    """
    assert isinstance(the_dict, Mapping), f"{the_dict!r} is not a dict or mapping, it is a {type(the_dict)}"

    frame_up = sys._getframe(1)  # type: ignore[reportPrivateUsage]
    assert frame_up is not None
    module = inspect.getmodule(frame_up)
    assert module is not None, f"Couldn't get module for frame {frame_up}"
    msg_prefix = f"In module `{module.__name__}`,"
    module_dict = frame_up.f_locals
    name: str | None = None

    # Find the object:
    names = [k for k, val in module_dict.items() if val is the_dict]
    assert names, f"{msg_prefix} there is no name for {the_dict}, please check"

    # Any name that has a type hint will do, there will usually be one.
    hints = typing.get_type_hints(module)
    hinted_names = [name for name in names if name in hints]
    assert (
        hinted_names
    ), f"{msg_prefix} no type hints were found for {', '.join(names)}, they are needed to use assert_complete_enumerations_dict"
    name = hinted_names[0]
    hint = hints[name]
    origin = typing.get_origin(hint)
    assert origin is not None, f"{msg_prefix} type hint for {name} must supply arguments"
    assert origin in (
        dict,
        typing.Mapping,
        Mapping,
        frozendict,
    ), f"{msg_prefix} type hint for {name} must be dict/frozendict/Mapping with arguments to use assert_complete_enumerations_dict, not {origin}"
    args = typing.get_args(hint)
    assert len(args) == 2, f"{msg_prefix} type hint for {name} must have two args"
    arg0, _ = args
    arg0_origin = typing.get_origin(arg0)

    if arg0_origin is tuple:
        # tuple of Enums
        enum_list = typing.get_args(arg0)
        for enum_cls in enum_list:
            # Check isinstance first: issubclass() raises TypeError for non-classes.
            assert isinstance(enum_cls, type) and issubclass(
                enum_cls, Enum
            ), f"{msg_prefix} type hint must be an Enum to use assert_complete_enumerations_dict, not {enum_cls}"
        items = list(itertools.product(*(list(enum_cls) for enum_cls in enum_list)))
    elif arg0_origin is typing.Literal:
        items = typing.get_args(arg0)
    else:
        # Same guard here, e.g. for a key hint like `int | str`.
        assert isinstance(arg0, type) and issubclass(
            arg0, Enum
        ), f"{msg_prefix} type hint must be an Enum to use assert_complete_enumerations_dict, not {arg0}"
        items = list(arg0)

    for item in items:
        if item in allowed_missing:
            continue
        # This is the assert we actually want to do, everything else is just error checking:
        assert item in the_dict, f"{msg_prefix} {name} needs an entry for {item}"
```