01/24/2023 11:42 | Category: python

Tags: dataclass, inheritance

python dataclass inheritance with default values is wonky

Recently I had an opportunity to use Python's @dataclass to handle some data that I wanted to behave the way a struct would in other languages.

The idea was I'd have a dataclass that held quite a few values that I needed to pass around and return to another consumer of the data, but I wanted to avoid a lengthy __init__() with lots of arguments.

Problems

Below are a few of the snafus that I ran into when working with dataclasses. While they're not dealbreakers, they were a bit interesting to work with and are something to consider when creating reusable (and understandable) data representations.

Instantiating without all the values

I immediately ran into a few issues with this, namely with instantiation of the class requiring all values to be present when it was created. Now, this could be remediated with default values - but I felt that defeated the purpose of the implementation. This was especially the case when a few of my fields were either calculated fields (relying on other data) or required some API calls to populate.

I found that the best course of action here was to use __post_init__() to populate the values that required special handling, and to mark those fields with field(init=False) so they are not constructor arguments.

Example:

from dataclasses import dataclass, field

@dataclass
class Example:

    var_one: int
    var_two: int
    sum_one_two: int = field(init=False)

    def __post_init__(self) -> None:
        self.sum_one_two = self.var_one + self.var_two
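
As a quick sanity check of the pattern above, the sketch below repeats the class and instantiates it. Because sum_one_two is declared with field(init=False), it is left out of the generated __init__() and filled in by __post_init__() instead. The concrete values are only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    var_one: int
    var_two: int
    # Excluded from __init__(); populated in __post_init__() instead.
    sum_one_two: int = field(init=False)

    def __post_init__(self) -> None:
        self.sum_one_two = self.var_one + self.var_two


example = Example(var_one=2, var_two=3)
print(example)  # Example(var_one=2, var_two=3, sum_one_two=5)

# Passing the calculated field to the constructor is rejected.
try:
    Example(var_one=2, var_two=3, sum_one_two=5)
except TypeError as error:
    print(f"TypeError: {error}")
```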

Inheritance with default values

Handling inheritance was an interesting issue: fields are collected parent-first, so once a parent class declares a field with a default value, any field the child adds without a default lands after it and the class fails to build.

The StackOverflow discussion of this problem boils down to:

  • When you're using dataclass inheritance, field order matters first and foremost

  • Default values are likely best placed in separate "base classes" to avoid instantiation errors

  • Default values should always be segmented away from any __post_init__ processing that you may have in your dataclasses
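
To make the first two points concrete, here is a minimal sketch. All class names (Parent, BrokenChild, _RequiredBase, and so on) and the fields var_three and var_four are hypothetical and chosen only for illustration. It first triggers the error you get when a child adds a required field after a parent's defaulted one, then shows the split-base layout, which relies on dataclasses collecting fields in reverse MRO order so that every required field ends up ahead of every defaulted one.

```python
from dataclasses import dataclass, fields


@dataclass
class Parent:
    var_one: int
    var_two: int = 1  # defaulted field in the parent


# A required field in the child lands after var_two, so the decorator fails.
try:
    @dataclass
    class BrokenChild(Parent):
        var_three: int
except TypeError as error:
    print(f"TypeError: {error}")


# Workaround: keep required and defaulted fields in separate bases.
@dataclass
class _RequiredBase:
    var_one: int


@dataclass
class _DefaultsBase:
    var_two: int = 1


@dataclass
class Example(_DefaultsBase, _RequiredBase):
    pass


@dataclass
class _ChildRequired(_RequiredBase):
    var_three: int


@dataclass
class _ChildDefaults(_DefaultsBase):
    var_four: int = 4


# Bases that carry defaults are listed first, so their fields are collected last.
@dataclass
class ExampleChild(_ChildDefaults, Example, _ChildRequired):
    pass


print([f.name for f in fields(ExampleChild)])
child = ExampleChild(var_one=10, var_three=30)
print(child)
```

The resulting constructor takes the required fields first (var_one, var_three) and the defaulted ones after (var_two, var_four). It works, but the number of helper classes grows quickly, which is the clunkiness discussed next.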

Naive solution to work around this

I found the idea of having multiple dataclasses as bases/parents to inherit from, just to separate out the default values, to be a bit clunky.

To avoid this, I instead initialize the parent class first by calling its __post_init__ from the child's __post_init__, then handle the remaining non-default values for the child class.

In the example below we can see that our parent Example is initialized first and has its __post_init__ run, then we handle the non-default/calculated fields.

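If we were to skip the super().__post_init__() call, sum_one_two would never be assigned and the child's __post_init__ would fail with an AttributeError. The field(init=False) markers matter too: declaring sum_one_two_three as a plain annotated field would raise TypeError: non-default argument 'sum_one_two_three' follows default argument when the class is defined, because it would become a required constructor argument placed after defaulted ones.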

from dataclasses import dataclass, field

@dataclass
class Example:

    var_one: int = 1
    var_two: int = 1
    sum_one_two: int = field(init=False)

    def __post_init__(self) -> None:
        self.sum_one_two = self.var_one + self.var_two


@dataclass
class ExampleChild(Example):

    var_three: int = 1
    sum_one_two_three: int = field(init=False)

    def __post_init__(self) -> None:
        super().__post_init__()

        self.sum_one_two_three = self.sum_one_two + self.var_three
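
A short usage sketch of the classes above, repeated here so the snippet runs on its own. The constructor values are arbitrary, and ForgetfulChild is a hypothetical class added only to show what happens when the super().__post_init__() call is left out.

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    var_one: int = 1
    var_two: int = 1
    sum_one_two: int = field(init=False)

    def __post_init__(self) -> None:
        self.sum_one_two = self.var_one + self.var_two


@dataclass
class ExampleChild(Example):
    var_three: int = 1
    sum_one_two_three: int = field(init=False)

    def __post_init__(self) -> None:
        # Let the parent compute sum_one_two before we depend on it.
        super().__post_init__()
        self.sum_one_two_three = self.sum_one_two + self.var_three


@dataclass
class ForgetfulChild(Example):
    var_three: int = 1
    sum_one_two_three: int = field(init=False)

    def __post_init__(self) -> None:
        # No super().__post_init__(), so sum_one_two is never assigned.
        self.sum_one_two_three = self.sum_one_two + self.var_three


defaults = ExampleChild()
custom = ExampleChild(var_one=2, var_two=3, var_three=4)
print(defaults.sum_one_two_three)  # 3
print(custom.sum_one_two_three)    # 9

try:
    ForgetfulChild()
except AttributeError as error:
    print(f"AttributeError: {error}")
```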

Dictionary representations with inheritance

When working with the dictionary representations of this information on the consumer side of these dataclasses we may run into issues getting the inheritance and various other things sorted out.

In some cases we may need to remove keys or need to return a representation that a simple asdict(MyDataclass) does not handle well.

This is particularly true with inheritance, where I noticed that I received a "recursive limit reached" error when attempting to return an asdict() representation.

To fix this, I used a workaround that builds the dictionary representation from fields(), pairing each field name with its value on the instance.

from dataclasses import dataclass, fields
from typing import Any, Dict, Union

# Example set of keys to remove before returning to the consumer
REMOVE_THESE_KEYS = {'_internal_key_id_one', '_internal_key_id_two'}


@dataclass
class Example:
    pass


@dataclass
class ExampleChild(Example):
    pass


def get_dict_representation(datarepresentation: Union[Example, ExampleChild]) -> Dict[str, Any]:
    # Build a shallow dict from the dataclass fields, skipping internal keys.
    return {
        f.name: getattr(datarepresentation, f.name)
        for f in fields(datarepresentation)
        if f.name not in REMOVE_THESE_KEYS
    }
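
A usage sketch for the helper above. The field names here (name, nickname, and the value given to _internal_key_id_one) are hypothetical placeholders, since the classes in the snippet are empty. fields() also returns inherited fields, so the child's dictionary includes the parent's keys minus the filtered ones.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict

REMOVE_THESE_KEYS = {"_internal_key_id_one", "_internal_key_id_two"}


@dataclass
class Example:
    name: str
    _internal_key_id_one: int = 0


@dataclass
class ExampleChild(Example):
    nickname: str = ""


def get_dict_representation(datarepresentation: Example) -> Dict[str, Any]:
    # Shallow conversion: nested values are returned as-is, not converted.
    return {
        f.name: getattr(datarepresentation, f.name)
        for f in fields(datarepresentation)
        if f.name not in REMOVE_THESE_KEYS
    }


child = ExampleChild(name="widget", _internal_key_id_one=42, nickname="w")
print(get_dict_representation(child))  # {'name': 'widget', 'nickname': 'w'}
```

Unlike asdict(), this does not recurse into nested dataclasses, lists, or dicts. That is probably why it sidesteps the recursion error, at the cost of leaving nested values untouched.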