目录

classmethod构建多态通用的类

概述

多态是指继承体系中的多个类都可以以各自独有的方式来实现某个方法,这些类都满足相同的接口或者继承自相同的抽象类,但却各有功能。

代码

下面是一个类继承的实例。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class InputData(object):
    def read(self):
        raise NotImplementedError


class PathInputData(InputData):
    def __init__(self, path):
        super(PathInputData, self).__init__()
        self.path = path

    def read(self):
        return open(self.path).read()


class Worker(object):
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None

    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError


class LineCountWorker(Worker):
    def map(self):
        data = self.input_data.read()
        self.result = data.count('\n')

    def reduce(self, other):
        self.result += other.result

InputData 为抽象类,其 read 方法需要由继承他的 PathInputData 子类来实现。然后是一个关于工作线程的一套抽象接口,Worker 类作为抽象类,需要子类 LineCountWorker 来实现 mapreduce 方法。为了串联这些类,可以通过一些辅助类、辅助方法来实现。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))


def create_workers(input_list):
    workers = []
    for input_data in input_list:
        workers.append(LineCountWorker(input_data))
    return workers


def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    first, rest = workers[0], workers[1:]
    for worker in rest:
        first.reduce(worker)
    return first.result


def mapreduce(data_dir):
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)


if __name__ == '__main__':
    print mapreduce("/Users/runzhliu/workspace/python-utils/data")

可以看到 mapreduce 方法实际集成了多个辅助方法,首先是创建一系列 input 然后通过创建 worker 来调度线程,最后是 excute 方法。实现起来很美,只是 mapreduce 不够通用啊!那具体是哪儿不通用呢,当需要编写其他的 InputDataWorker 子类,那就得重写那几个辅助方法了。而为了解决这个问题,需要一种更通用的方式来创建对象。在 Python 中,只允许 __init__ 这个构造器方法,而不像在 Java 或者 Scala 中可以通过构造器多态来实现。

尽管如此,Python 中 @classmethod 形式的多态可以解决这个问题。可以看一下实例。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class GenericInputData(object):
    def read(self):
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError


class PathInputData(GenericInputData):
    def read(self):
        return open(self.path).read()

    @classmethod
    def generate_inputs(cls, config):
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            yield cls(os.path.join(data_dir, name))


class GenericWorker(object):
    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        workers = []
        for input_data in input_class.generate_inputs(config):
            workers.append(cls(input_data))

        return workers


class LineCountWorker(GenericWorker):
    def reduce(self, other):
        self.result += other.result

    def map(self):
        data = self.input_data.read()
        self.result = data.count('\n')

对于像 InputData.read 那样的实例方法多态非常相似,只不过它会针对整个类,而不是从类中构建出来的对象。给抽象类 GenericInputData 提供一个 generate_inputs,可以接受一个含有配置参数的字典,而具体的子类则可以解读这些参数。同时给 GenericWorker 定义创建工作线程的辅助方法 create_workers,该方法中,input_class.generate_inputs 是个类级别的多态方法。在此处通过 cls 形式构造 GenericWorker 对象。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    first, rest = workers[0], workers[1:]
    for worker in rest:
        first.reduce(worker)
    return first.result


def mapreduce(work_class, input_class, config):
    workers = work_class.create_workers(input_class, config)
    return execute(workers)
    
if __name__ == '__main__':
    print mapreduce(LineCountWorker, PathInputData, config={'data_dir': "/Users/runzhliu/workspace/python-utils/data"})

最后是重写一些辅助方法。显然此时的 mapreduce 方法可以通过传入不同的子类,而不用重写其他辅助方法来达到通用的目的。

参考资料

  1. Effective Python
警告
本文最后更新于 2017年2月1日,文中内容可能已过时,请谨慎参考。