时雨小径 May the Spirit be with you

MySQL Interface in Python

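The code that updates the vectors in the recommender system I wrote earlier is in Python. Because it has to update the database in real time while it runs, it hit the database so often that it was very slow. I've had some free time lately, so I wrote a small interface-like layer that reads the data into memory first, which makes the operations considerably faster.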

The rough idea is:

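  1. Connect to the database with MySQLdb and read all the tables involved into memory in one go.
  2. Query, insert and update the in-memory data with functions written in Python, marking the rows that are newly inserted or updated.
  3. When the operations are finished, save the results back to the database.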
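Put together, the three steps form a write-back cache: reads and writes hit memory, and the database is touched only at load and save time. A minimal sketch of that pattern, assuming Python's built-in `sqlite3` in place of MySQLdb so it runs without a server (sqlite3 uses `?` placeholders where MySQLdb uses `%s`); `load`, `update`, `save` and the table `vec` are hypothetical names, and the cache here is a dict keyed by primary key rather than a sorted list:

```python
import sqlite3

# sqlite3 stands in for MySQLdb so this runs without a MySQL server;
# the table `vec` and its columns are made up for the example.
con = sqlite3.connect(':memory:')
con.execute('CREATE TABLE vec (id INTEGER PRIMARY KEY, weight REAL)')
con.executemany('INSERT INTO vec VALUES (?, ?)', [(1, 0.5), (2, 1.5)])
con.commit()


def load(con):
    """Step 1: read the whole table into memory once, keyed by primary key."""
    rows = con.execute('SELECT id, weight FROM vec ORDER BY id').fetchall()
    return {pk: {'id': pk, 'weight': weight} for pk, weight in rows}


def update(cache, dirty, pk, weight):
    """Step 2: change the in-memory row and mark it as dirty."""
    cache[pk]['weight'] = weight
    dirty.add(pk)


def save(con, cache, dirty):
    """Step 3: write only the dirty rows back, in a single transaction."""
    con.executemany('UPDATE vec SET weight = ? WHERE id = ?',
                    [(cache[pk]['weight'], pk) for pk in sorted(dirty)])
    con.commit()
    dirty.clear()


cache = load(con)
dirty = set()
for _ in range(1000):  # many updates, no database round-trips
    update(cache, dirty, 2, cache[2]['weight'] + 0.001)
save(con, cache, dirty)  # one UPDATE for the single changed row
```

A thousand in-memory updates end in a single UPDATE statement. With a real MySQL connection the structure stays the same; mainly the placeholder style and the connect call change.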

The main problems I ran into while debugging were:

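  1. Primary keys: some tables use an auto-increment primary key, while others need the key to be specified on insert. Getting this wrong makes the in-memory data differ from what is actually in the database.
  2. Foreign keys: same reason as above.
  3. Lookups: tables are read in ascending primary-key order, so a lookup by primary key can use binary search, but for other columns I haven't found a good method yet (they are scanned linearly). After inserting new rows (especially when the primary key is not auto-increment), the ascending order must be preserved, or the binary search will return wrong results.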
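For point 3, Python's standard `bisect` module can do the binary search both for lookups and for finding where a new key belongs, so an insert can never break the ordering. A minimal sketch, assuming a separate ascending list of primary keys kept parallel to the row list (the `key=` argument of `bisect` only exists from Python 3.10, so a parallel list is the portable choice); `insert_sorted` and `find_by_pk` are hypothetical helpers, not methods of the class below:

```python
import bisect

# Primary keys in ascending order, kept parallel to the list of rows.
keys = [1, 3, 4]
rows = [{'id': 1}, {'id': 3}, {'id': 4}]


def insert_sorted(keys, rows, row, pk='id'):
    """Insert row so that keys stays ascending; reject duplicate primary keys."""
    i = bisect.bisect_left(keys, row[pk])
    if i < len(keys) and keys[i] == row[pk]:
        raise ValueError('duplicate primary key: %r' % (row[pk],))
    keys.insert(i, row[pk])
    rows.insert(i, row)


def find_by_pk(keys, rows, key):
    """Binary search by primary key; returns the row or None."""
    i = bisect.bisect_left(keys, key)
    if i < len(keys) and keys[i] == key:
        return rows[i]
    return None


insert_sorted(keys, rows, {'id': 2})  # lands between 1 and 3
insert_sorted(keys, rows, {'id': 9})  # largest key: goes to the end
```

Finding the position is O(log n); `list.insert` still shifts the later elements, so each insert is O(n) overall, but it needs far fewer comparisons than a linear scan for the position.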
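# Assumes these module-level objects are created elsewhere in the script:
#   dbCon     -- an open MySQLdb connection
#   dbCursor  -- a cursor on dbCon, used by sql_insert/sql_update
#   progress_refresh(message, current, total) -- a progress-display helper
class DBError(Exception):
    def __init__(self, value):
        Exception.__init__(self, value)
        self.value = value

    def __str__(self):
        return repr(self.value)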

class MyDBTable:
    def __init__(self, var_table_name=None, var_columns=None, var_pk=None):
        self.table_name = var_table_name
        self.columns = list(var_columns) if var_columns is not None else []
        self.pk = var_pk
        # Load the whole table once, sorted by primary key so that
        # binary_search() can be used for primary-key lookups.
        self.cur = dbCon.cursor()
        self.cur.execute('SELECT ' + ','.join('`' + col + '`' for col in self.columns) +
                         ' FROM `' + self.table_name + '` ORDER BY `' + self.pk + '` ASC')
        self.data = [dict(zip(self.columns, row)) for row in self.cur.fetchall()]
        # Next free primary key (max + 1); an empty table starts at 1, like AUTO_INCREMENT.
        self.top_id = self.data[-1][self.pk] + 1 if self.data else 1
        print(self.table_name + ' loaded.')

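    def sql_save(self):
        # Write back only the rows touched since loading, then commit.
        total_val = len(self.data)
        for current_val, val in enumerate(self.data):
            progress_refresh('saving to ' + self.table_name, current_val, total_val)
            if val.get('flag_insert'):
                self.sql_insert(val)
            elif val.get('flag_update'):
                self.sql_update(val)
            # Clear the flags so that a second save does not insert the row again.
            val.pop('flag_insert', None)
            val.pop('flag_update', None)
        dbCon.commit()  # MySQLdb does not autocommit by default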

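    def sql_update(self, var_row):
        # Parameterized query: MySQLdb quotes each value itself, so strings and
        # NULLs work (formatting every column as a 5-decimal number did not).
        other_cols = [col for col in self.columns if col != self.pk]
        if not other_cols:
            return
        sql_string = ('UPDATE `' + self.table_name + '` SET ' +
                      ', '.join('`' + col + '` = %s' for col in other_cols) +
                      ' WHERE `' + self.pk + '` = %s')
        params = [var_row[col] for col in other_cols] + [var_row[self.pk]]
        dbCursor.execute(sql_string, tuple(params))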

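    def sql_insert(self, var_row):
        # The primary key is always written explicitly, so the database row keeps
        # the same id as the in-memory row, also for AUTO_INCREMENT tables.
        insert_columns = '(' + ','.join('`' + col + '`' for col in self.columns) + ')'
        placeholders = '(' + ','.join(['%s'] * len(self.columns)) + ')'
        sql_string = 'INSERT INTO `' + self.table_name + '` ' + insert_columns + ' VALUES ' + placeholders
        dbCursor.execute(sql_string, tuple(var_row[col] for col in self.columns))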

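    def insert(self, var_row):
        # Reject an explicit primary key that already exists in memory.
        if var_row.get(self.pk) is not None and self.find(var_row[self.pk]) is not None:
            raise DBError('Primary key duplication occurs when inserting data: ' + repr(var_row))

        # Columns missing from var_row default to 0.
        new_row = dict.fromkeys(self.columns, 0)
        for col in self.columns:
            if col in var_row:
                new_row[col] = var_row[col]
        if var_row.get(self.pk) is None:
            # No key given: use the next id after the in-memory maximum (1 if empty).
            new_row[self.pk] = self.data[-1][self.pk] + 1 if self.data else 1
        new_row['flag_insert'] = True

        # Keep self.data sorted by primary key so binary_search() stays valid:
        # insert before the first larger key, or append if the new key is the largest.
        for idx, row in enumerate(self.data):
            if row[self.pk] > new_row[self.pk]:
                self.data.insert(idx, new_row)
                break
        else:
            self.data.append(new_row)
        self.top_id = self.data[-1][self.pk] + 1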

    def update(self, var_row):
        if var_row.get(self.pk) is None:
            raise DBError('Primary key missing when updating data: ' + repr(var_row))
        idx = self.binary_search(var_row[self.pk])
        if idx == -1:  # not found in memory
            raise DBError('Primary key not found when updating data: ' + repr(var_row))
        # Copy only this table's columns, so stray keys cannot overwrite the flags.
        for col in var_row:
            if col in self.columns:
                self.data[idx][col] = var_row[col]
        self.data[idx]['flag_update'] = True

    def find_one_by_col(self, col=None, key=None, columns=None):
        columns = columns if columns is not None else self.columns
        if col in self.columns:
            for row in self.data:
                if row[col] == key:
                    return {i: row[i] for i in columns}
        return None

    def find_all_by_col(self, col=None, key=None, columns=None):
        columns = columns if columns is not None else self.columns
        result_set = []
        if col in self.columns:
            for row in self.data:
                if row[col] == key:
                    ret = {i: row[i] for i in columns}
                    result_set.append(ret)
        return result_set

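    def binary_search(self, key=None):
        # Index of the row whose primary key equals key, or -1 if absent.
        # Requires self.data to be sorted by primary key in ascending order.
        if key is None:
            return -1
        low = 0
        high = len(self.data) - 1
        while low <= high:
            mid = (low + high) // 2  # floor division, correct on Python 2 and 3
            val = self.data[mid][self.pk]
            if val < key:
                low = mid + 1
            elif val > key:
                high = mid - 1
            else:
                return mid
        return -1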

    def find(self, key=None, columns=None):
        if key is None:
            return None
        columns = columns if columns is not None else self.columns
        idx = self.binary_search(key)
        if idx != -1:
            return {i: self.data[idx][i] for i in columns}
        return None
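On the open question of lookups by columns other than the primary key: `find_one_by_col` and `find_all_by_col` scan every row. One common alternative is a hash index, a dict from a column's value to the primary keys of the rows holding it, built once after loading. A sketch under those assumptions; `build_index` and the column `user_id` are hypothetical, and keeping the index current in `insert` and `update` is not shown:

```python
from collections import defaultdict


def build_index(rows, col, pk='id'):
    """Map each value of column `col` to the primary keys of rows holding it."""
    index = defaultdict(list)
    for row in rows:
        index[row[col]].append(row[pk])
    return index


# Hypothetical rows, in ascending primary-key order as after loading.
rows = [
    {'id': 1, 'user_id': 7, 'score': 0.2},
    {'id': 2, 'user_id': 8, 'score': 0.9},
    {'id': 3, 'user_id': 7, 'score': 0.4},
]
by_user = build_index(rows, 'user_id')

# Same result as scanning for user_id == 7, without touching every row.
# A dict by primary key stands in here for MyDBTable.binary_search().
rows_by_pk = {row['id']: row for row in rows}
matches = [rows_by_pk[pk] for pk in by_user.get(7, [])]
```

Building the index costs one pass over the table; after that each lookup is a dict access plus one primary-key lookup per match. The price is extra memory and the need to update the index whenever the indexed column changes.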