pytyrantを継承して、文字列以外のオブジェクトも入れられるようにした

こんばんは、橋本大輔です。

gumiのソーシャルアプリではDjangoから

memcached
from django.core.cache import cache
http://docs.djangoproject.com/en/1.2/topics/cache/

tokyotyrant
from pytyrant import Tyrant
http://code.google.com/p/pytyrant/

を使用しています。

 
使い分けの話は置いておきますが、
私は

cache.set(key, val, 86400)

が好きです。なぜならあんなものやそんなものが何でもそのまま突っ込めるから!
対して、tokyotyrantを使用する場合は毎回文字列に変換したりしていて、これがめんどくさい。

Tyrantクラスにも同じようなインターフェースがあればいいのに、と思って作ってみました。
# ソースはほぼmemcache.pyからパクらせてもらいました!!

何でも入れられるからって何でも入れていいのかどうかって話は
入れてみて考えれば良いと思う!

# -*- coding: utf-8 -*-

import socket

try:
    import cPickle as pickle
except ImportError:
    import pickle

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

from pytyrant import Tyrant


# Size limits carried over from memcache.py; TyrantForObject uses them as the
# defaults for its ``server_max_*_length`` options.
SERVER_MAX_KEY_LENGTH = 250        # stored by the client, not checked in this module
SERVER_MAX_VALUE_LENGTH = 1 << 20  # 1 MiB cap on the pickled value; 0 disables

# Port TyrantForObject.open connects to when none is given.
DEFAULT_PORT = 1978

class TyrantForObject(Tyrant):
    """Tyrant client that can also store arbitrary picklable objects.

    Adds ``set_object`` / ``get_object`` on top of the string-based
    ``put`` / ``get`` inherited from ``Tyrant``. Values are serialized with
    pickle, and the (un)pickling behaviour can be customised through keyword
    options in the style of python-memcached's client.

    WARNING: ``get_object`` unpickles whatever the server returns, and
    unpickling can execute arbitrary code. Only use it against a server whose
    contents you fully trust.
    """

    @classmethod
    def open(cls, host='127.0.0.1', port=DEFAULT_PORT, **opts):
        """Connect to a Tokyo Tyrant server and return a client instance.

        :param host: server host name or address.
        :param port: server TCP port.
        :param opts: forwarded to ``__init__`` (pickling options, size limits).
        :raises socket.error: if the connection cannot be established.
        """
        sock = socket.socket()
        try:
            sock.connect((host, port))
            # TCP_NODELAY disables Nagle's algorithm so small requests are
            # sent immediately instead of being coalesced.
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            return cls(sock, **opts)
        except Exception:
            # Do not leak the socket if connecting or construction fails.
            sock.close()
            raise

    def __init__(self, sock, **opts):
        """Wrap an already-connected socket.

        Recognised keyword options (all optional):

        * ``pickleProtocol``: pickle protocol used for stored values (default 0).
        * ``pickler`` / ``unpickler``: factories with the ``pickle.Pickler`` /
          ``pickle.Unpickler`` call interface.
        * ``pid`` / ``pload``: ``persistent_id`` / ``persistent_load`` hooks.
        * ``server_max_key_length``: stored on the instance but not checked
          by this class (default ``SERVER_MAX_KEY_LENGTH``).
        * ``server_max_value_length``: size limit on the pickled value; 0
          disables the check (default ``SERVER_MAX_VALUE_LENGTH``).
        """
        self.sock = sock

        # Allow users to modify pickling/unpickling behavior
        self.pickleProtocol = opts.get('pickleProtocol', 0)
        self.pickler = opts.get('pickler', pickle.Pickler)
        self.unpickler = opts.get('unpickler', pickle.Unpickler)
        self.persistent_load = opts.get('pload', None)
        self.persistent_id = opts.get('pid', None)
        self.server_max_key_length = opts.get('server_max_key_length', SERVER_MAX_KEY_LENGTH)
        self.server_max_value_length = opts.get('server_max_value_length', SERVER_MAX_VALUE_LENGTH)

        # Probe whether the pickler accepts ``protocol`` as a keyword. If it
        # raises TypeError, _val_to_pickled_info passes it positionally instead.
        try:
            self.pickler(StringIO(), protocol=self.pickleProtocol)
            self.picklerIsKeyword = True
        except TypeError:
            self.picklerIsKeyword = False

    def _val_to_pickled_info(self, val):
        """Pickle ``val`` and return the pickled string.

        :raises ValueError: if the pickled data is ``server_max_value_length``
            or longer. The check is skipped when that limit is 0.
        """
        buf = StringIO()
        if self.picklerIsKeyword:
            pickler = self.pickler(buf, protocol=self.pickleProtocol)
        else:
            pickler = self.pickler(buf, self.pickleProtocol)
        if self.persistent_id:
            pickler.persistent_id = self.persistent_id
        pickler.dump(val)
        pickled = buf.getvalue()

        # Can not store if value length exceeds maximum.
        limit = self.server_max_value_length
        if limit != 0 and len(pickled) >= limit:
            raise ValueError(
                'pickled value is %d bytes, must be less than %d'
                % (len(pickled), limit))

        return pickled

    def _set_object(self, key, val):
        """Pickle ``val`` and store it under ``key`` via the inherited ``put``."""
        pickled_info = self._val_to_pickled_info(val)
        self.put(key, pickled_info)

    def set_object(self, key, val):
        """Store any picklable ``val`` under ``key``.

        Returns None.

        :raises ValueError: if the pickled value exceeds the size limit.
        """
        return self._set_object(key, val)

    def _pickled_info_to_val(self, pickled_info):
        """Unpickle a string produced by ``_val_to_pickled_info``.

        Unpickling untrusted data can execute arbitrary code.
        """
        buf = StringIO(pickled_info)
        unpickler = self.unpickler(buf)
        if self.persistent_load:
            unpickler.persistent_load = self.persistent_load
        return unpickler.load()

    def _get_object(self, key):
        """Fetch ``key`` via the inherited ``get`` and unpickle the result."""
        pickled_info = self.get(key)
        return self._pickled_info_to_val(pickled_info)

    def get_object(self, key):
        """Return the object previously stored under ``key`` by ``set_object``.

        Errors raised by the inherited ``get`` (e.g. for a missing key; the
        exact behaviour depends on pytyrant) propagate unchanged.
        """
        return self._get_object(key)


if __name__ == '__main__':
    # Smoke test: requires a Tokyo Tyrant server listening on 127.0.0.1:1978.
    # It writes key001..key006 to that server and does not delete them.
    print('START SET OBJECT TEST')
    tyrant_client = TyrantForObject.open('127.0.0.1', 1978)

    # C is a module-level name (an ``if`` block does not open a new scope),
    # so pickle can locate it as ``__main__.C`` when unpickling.
    # __eq__/__ne__ replace __cmp__ so the comparison also works on Python 3,
    # where __cmp__ and cmp() no longer exist.
    class C(object):
        def __eq__(self, other):
            return self.hoge == other.hoge

        def __ne__(self, other):
            return not self == other

    c = C()
    c.hoge = 'hoge'

    try:
        # The plain string API inherited from Tyrant (put/get) still works.
        val001 = 'val001'
        tyrant_client.put('key001', val001)
        ret001 = tyrant_client.get('key001')
        # Print what was read back, consistent with the cases below.
        print(ret001)
        assert val001 == ret001

        # Round-trip various object types through set_object/get_object.
        cases = [
            ('key002', 'val002'),                  # str
            ('key003', 12345),                     # int
            ('key004', [1, 2, 3, 4, 5]),           # list
            ('key005', {'1': 1, '2': 2, '3': 3}),  # dict
            ('key006', c),                         # user-defined object
        ]
        for key, val in cases:
            tyrant_client.set_object(key, val)
            ret = tyrant_client.get_object(key)
            print(ret)
            assert val == ret
    finally:
        # Release the connection even if an assertion fails.
        tyrant_client.sock.close()