Saturday, August 28, 2010

Case classes in Python

I'm trying to build a parser. The main function I'm writing is called parse. It takes a str and returns a parse tree of token objects. I have a heap of tests that parse some text and then compare the parse tree against an expected parse tree.

The way I started testing this looked a little like:

import unittest

# token classes (Token and parse are defined elsewhere in the parser)
class Add(Token):
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __repr__(self):
        return 'Add(' + repr(self.left) + ', ' + repr(self.right) + ')'

class Number(Token):
    def __init__(self, numberStr):
        self.numberStr = numberStr

    def __repr__(self):
        return 'Number(' + repr(self.numberStr) + ')'

# tests
class Tests(unittest.TestCase):
    def test_add(self):
        self.assertEqual(
            "Add(Number('1'), Number('1'))",
            repr(parse('1+1')))

Making the token classes was a real chore. After a while I had around 10 of them, and writing each one was increasingly tedious. It occurred to me that if I had something like Scala's case classes, I could change the code to something like:

# token classes
class Add(Token):
    def __init__(self, left, right):
        pass

class Number(Token):
    def __init__(self, numberStr):
        pass

# tests
class Tests(unittest.TestCase):
    def test_add(self):
        self.assertEqual(
            Add(Number('1'), Number('1')),
            parse('1+1'))

This way the token classes basically write themselves and, because they get a generated __eq__ method, the tests don't need those messy string comparisons.
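Part of that gain doesn't need a metaclass at all: a shared base class can compare tokens by class and attributes. The following is a hedged Python 3 sketch that assumes Token can be turned into such a base class (its real definition isn't shown here). It removes the string comparisons, but every token class still has to spell out its own __init__ body, which is exactly the tedium a metaclass can take away.

```python
class Token:
    """Hypothetical base class giving every token value-based equality."""

    def __eq__(self, other):
        # Same concrete class and equal instance attributes means equal tokens.
        if type(self) is not type(other):
            return NotImplemented
        return vars(self) == vars(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__ (attribute values must be hashable).
        return hash((type(self), tuple(sorted(vars(self).items()))))

    def __repr__(self):
        # Relies on attributes being assigned in constructor-argument order.
        values = ', '.join(repr(v) for v in vars(self).values())
        return '%s(%s)' % (type(self).__name__, values)


class Add(Token):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Number(Token):
    def __init__(self, numberStr):
        self.numberStr = numberStr


# A test can now compare trees directly instead of comparing repr strings.
assert Add(Number('1'), Number('1')) == Add(Number('1'), Number('1'))
```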

I did some research and found that I could get most of the case class functionality (everything apart from pattern matching) using a custom metaclass.

Here's the result:

>>> from caseclasses import CaseMetaClass
>>>
>>> class MyCaseClass():
...     __metaclass__ = CaseMetaClass
...     def __init__(self, a, b):
...         pass
...
>>> instance = MyCaseClass(1, 'x')
>>> instance.a
1
>>> instance.b
'x'
>>> instance == MyCaseClass(1, 'y')
False
>>> instance == MyCaseClass(1, 'x')
True
>>> str(instance)
"MyCaseClass(1, 'x')"

This is how it works. Case classes are marked with the CaseMetaClass. For each argument of the __init__ method, a read-only property with the same name is generated; once an instance is constructed, it returns the value that was passed for that argument. The generated __eq__ compares the class and the constructor arguments, and __str__ (also used as __repr__) prints the class name followed by the repr of each argument, as in the example above. I also added simple implementations of __ne__ and __hash__ to be consistent with __eq__.
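The mechanism described above can be sketched in condensed form for Python 3, where the metaclass is given as class C(metaclass=...) rather than through a __metaclass__ attribute. This is a hedged illustration, not the post's code: CaseMeta and the _case_args attribute are hypothetical names, inspect.signature stands in for the decorator package, and for brevity the sketch only accepts plain named __init__ parameters (no *args).

```python
import inspect


class CaseMeta(type):
    """Hypothetical Python 3 case-class metaclass: fields come from __init__'s parameters."""

    def __new__(mcs, name, bases, namespace):
        # Generated methods must not be overridden by the class body.
        for meth in ('__eq__', '__ne__', '__hash__', '__str__', '__repr__'):
            if meth in namespace:
                raise TypeError('%s must not be defined on a case class' % meth)

        user_init = namespace.get('__init__', lambda self: None)
        sig = inspect.signature(user_init)
        params = list(sig.parameters.values())[1:]  # drop 'self'
        if any(p.kind is not p.POSITIONAL_OR_KEYWORD for p in params):
            raise TypeError('this sketch only supports plain named parameters')
        fields = tuple(p.name for p in params)

        for field in fields:
            if field.startswith('_'):
                raise TypeError("case class fields can't start with '_'")
            # Read-only property backed by a '_'-prefixed instance attribute.
            namespace[field] = property(lambda self, f=field: getattr(self, '_' + f))

        def __init__(self, *args, **kwargs):
            # Bind like a normal call: wrong arity raises TypeError, and keyword
            # arguments and defaults are normalized into parameter order.
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            self._case_args = tuple(bound.arguments.values())[1:]
            for field, value in zip(fields, self._case_args):
                setattr(self, '_' + field, value)

        namespace['__init__'] = __init__
        namespace['__repr__'] = lambda self: '%s(%s)' % (
            type(self).__name__, ', '.join(repr(v) for v in self._case_args))
        namespace['__eq__'] = lambda self, other: (
            type(self) is type(other) and self._case_args == other._case_args)
        # Equal instances hash equally; Python 3 derives != from __eq__.
        namespace['__hash__'] = lambda self: hash((type(self), self._case_args))
        return super().__new__(mcs, name, bases, namespace)


class MyCaseClass(metaclass=CaseMeta):
    def __init__(self, a, b):
        pass


instance = MyCaseClass(1, 'x')
print(instance)  # MyCaseClass(1, 'x')
```

Because the arguments are bound through the signature before being stored, MyCaseClass(a=1, b='x') compares equal to MyCaseClass(1, 'x').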

Getting this working was much easier than I expected. Here is the code (Python 2; it uses the third-party decorator package, which wraps __init__ while keeping its signature):

import inspect

from decorator import decorator

class CaseMetaClass(type):
    def __new__(mcs, name, bases, dict):
        def noop(self):
            pass

        # These methods are generated below, so refuse user-defined versions.
        for meth in ('__eq__', '__ne__', '__hash__', '__str__', '__repr__'):
            if meth in dict:
                raise Exception('%s must not be defined on class.' % meth)

        # The parameters of __init__ (minus self) become the case class fields.
        if '__init__' in dict:
            args, varargs, varkw, _ = inspect.getargspec(dict['__init__'])
            if varkw is not None:
                raise Exception("__init__ can't take **kwargs")
            args = args[1:]
        else:
            args = []
            varargs = None

        if args and varargs:
            raise Exception("Case class __init__ can't have both args (other than self) and *args")

        # One read-only property per field, backed by a '_'-prefixed attribute.
        for arg in args + ([varargs] if varargs else []):
            if arg.startswith('_'):
                raise Exception("Case class attributes can't start with '_'.")
            dict[arg] = property(lambda self, arg=arg: getattr(self, '_' + arg))

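        # The decorator-generated wrapper keeps the original __init__ signature
        # and calls _init with that function as func; its body is never run.
        def _init(func, self, *init_args):
            # Remember all constructor arguments for __eq__, __hash__ and __str__.
            setattr(self, '_CaseMetaClass__args', init_args)
            if varargs:
                setattr(self, '_' + varargs, init_args)
            else:
                for (field, value) in zip(args, init_args):
                    setattr(self, '_' + field, value)
        dict['__init__'] = decorator(_init, dict.get('__init__', noop))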

        # Used for both str() and repr(): ClassName(arg1, arg2, ...).
        def to_str(self):
            values = [repr(x) for x in getattr(self, '_CaseMetaClass__args')]
            return name + '(' + ', '.join(values) + ')'
        dict['__str__'] = to_str
        dict['__repr__'] = to_str

        def eq(self, other):
            if other is None:
                return False
            if type(self) is not type(other):
                return False
            return self._CaseMetaClass__args == other._CaseMetaClass__args
        dict['__eq__'] = eq

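        # __ne__ and __hash__ stay consistent with __eq__: equal instances
        # compare equal and produce equal hashes.
        dict['__ne__'] = lambda self, other: not (self == other)
        dict['__hash__'] = lambda self: hash(type(self)) ^ hash(self._CaseMetaClass__args)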

        return type.__new__(mcs, name, bases, dict)
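The argument inspection at the top of __new__ depends on inspect.getargspec, which exists in Python 2 but was removed in Python 3.11. As a hedged sketch, the same checks could be written with inspect.signature. Here init_fields is a hypothetical helper name, it raises TypeError instead of a bare Exception, and rejecting keyword-only parameters is an added assumption, since Python 2 has no such parameters.

```python
import inspect


def init_fields(init=None):
    """Hypothetical helper: return (named_args, varargs_name) for a case class __init__.

    Applies the same checks as CaseMetaClass, but with inspect.signature.
    """
    if init is None:  # no __init__ in the class body
        return [], None
    params = list(inspect.signature(init).parameters.values())[1:]  # drop 'self'
    args, varargs = [], None
    for p in params:
        if p.kind is p.VAR_KEYWORD:
            raise TypeError("__init__ can't take **kwargs")
        if p.kind is p.KEYWORD_ONLY:
            raise TypeError("keyword-only parameters aren't supported")
        if p.kind is p.VAR_POSITIONAL:
            varargs = p.name
        else:
            args.append(p.name)
    if args and varargs:
        raise TypeError("Case class __init__ can't have both args (other than self) and *args")
    return args, varargs


print(init_fields(lambda self, a, b: None))    # (['a', 'b'], None)
print(init_fields(lambda self, *items: None))  # ([], 'items')
```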