浏览 4945 次
锁定老帖子 主题:用sqlite3实现稀疏矩阵
精华帖 (0) :: 良好帖 (0) :: 新手帖 (0) :: 隐藏帖 (0)
|
|
---|---|
作者 | 正文 |
发表时间:2009-07-29
# The basic idea is to describe the matrix as 3-tuples (row coordinate,
# column coordinate, value) and to keep those tuples in an in-memory sqlite3
# table. The code is as follows:

import sqlite3


class SparseMatrix:
    """Sparse matrix whose non-zero cells live in an in-memory SQLite table.

    Each stored cell is one ``(x, y, value)`` row of the ``matrix`` table,
    where ``x`` is the row index and ``y`` the column index. Cells that are
    not stored read back as ``0.0``.

    Supports ``m[r, c]`` access and assignment, ``+``, ``+=``, ``-``, ``*``
    (matrix * matrix and matrix * scalar), ``len()`` (number of stored cells)
    and iteration over ``(row, column, value)`` tuples in row-major order.
    """

    def __init__(self, row_count=2147483647, column_count=2147483647):
        """Create an empty matrix.

        Args:
            row_count: Exclusive upper bound for row indices.
            column_count: Exclusive upper bound for column indices.
        """
        self.db = sqlite3.connect(":memory:")
        self.db.execute(
            "CREATE TABLE 'matrix' ('x' integer, 'y' integer, 'value' real, "
            "primary key('x', 'y'));"
        )
        self.row = row_count
        self.column = column_count

    def __del__(self):
        # __init__ may have failed before self.db was assigned.
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def _check_same_shape(self, other):
        """Raise IndexError unless *other* has the same declared dimensions."""
        if self.row != other.row or self.column != other.column:
            raise IndexError("matrix dimensions do not match")

    def __getitem__(self, index):
        """Return the value at ``index == (row, column)``; 0.0 if not stored.

        Only the upper bounds are checked against the declared dimensions.

        Raises:
            TypeError: If *index* is not a tuple.
            IndexError: If *index* does not have exactly two items or lies
                at or beyond the declared dimensions.
        """
        if not isinstance(index, tuple):
            raise TypeError("index must be a (row, column) tuple")
        if len(index) != 2:
            raise IndexError("index must have exactly two items")
        row, column = index
        if row >= self.row or column >= self.column:
            raise IndexError("index out of range")
        cursor = self.db.execute(
            "select value from matrix where x=? and y=?", index
        )
        found = cursor.fetchone()
        return found[0] if found is not None else 0.0

    def __setitem__(self, index, value):
        """Store *value* at ``index == (row, column)``, replacing any old cell.

        Raises:
            IndexError: If the index lies at or beyond the declared
                dimensions.
        """
        row, column = index
        if row >= self.row or column >= self.column:
            raise IndexError("index out of range")
        self.db.execute(
            "insert or replace into matrix values(?,?,?)", (row, column, value)
        )

    def __add__(self, other):
        """Return ``self + other`` as a new matrix.

        Raises:
            TypeError: If *other* is not a SparseMatrix.
            IndexError: If the declared dimensions differ.
        """
        if not isinstance(other, SparseMatrix):
            raise TypeError("unsupported operand type for +")
        # Same shape rule that __iadd__ and __sub__ enforce.
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = self[r, c] + v
        m.update()
        return m

    def __iadd__(self, other):
        """Add *other* into this matrix in place and return ``self``.

        Raises:
            TypeError: If *other* is not a SparseMatrix.
            IndexError: If the declared dimensions differ.
        """
        if not isinstance(other, SparseMatrix):
            raise TypeError("unsupported operand type for +=")
        self._check_same_shape(other)
        # For ``m += m`` the cells are snapshotted first, so the table is not
        # modified while a cursor is still iterating over it.
        cells = list(other) if other is self else other
        for row, column, value in cells:
            self[row, column] = self[row, column] + value
        self.update()
        return self

    def __sub__(self, other):
        """Return ``self - other`` as a new matrix.

        Raises:
            TypeError: If *other* is not a SparseMatrix.
            IndexError: If the declared dimensions differ.
        """
        if not isinstance(other, SparseMatrix):
            raise TypeError("unsupported operand type for -")
        self._check_same_shape(other)
        m = self.copy()
        for r, c, v in other:
            m[r, c] = self[r, c] - v
        m.update()
        return m

    def __mul__(self, other):
        """Return the matrix product or the scalar multiple as a new matrix.

        For a SparseMatrix operand the result has ``self.row`` rows and
        ``other.column`` columns. For an int or float operand every stored
        cell is scaled.

        Raises:
            TypeError: If *other* is neither a SparseMatrix nor a number.
        """
        if isinstance(other, SparseMatrix):
            m = SparseMatrix(self.row, other.column)
            columns = other.getallcolumns()
            for r in self.getallrows():
                # Column index -> value for the stored cells of row r,
                # fetched once per row instead of once per (row, column).
                row_map = dict(self.getrow(r))
                for c in columns:
                    products = [
                        row_map[k] * v
                        for k, v in other.getcolumn(c)
                        if k in row_map
                    ]
                    if products:
                        # The result cell belongs to column c; the original
                        # always wrote to column 0.
                        m[r, c] = sum(products)
            m.update()
            return m
        if isinstance(other, (int, float)):
            m = SparseMatrix(self.row, self.column)
            for r, c, v in self:
                m[r, c] = v * other
            m.update()
            return m
        raise TypeError("unsupported operand type for *")

    def __iter__(self):
        """Yield stored cells as ``(row, column, value)``, ordered by x, y."""
        cursor = self.db.execute("select x,y,value from matrix order by x,y")
        for cell in cursor:
            yield cell

    def __len__(self):
        """Return the number of stored cells."""
        cursor = self.db.execute("select count(*) from matrix")
        return cursor.fetchone()[0]

    def insert(self, cells):
        """Store every ``(row, column, value)`` in *cells*, then prune zeros."""
        for row, column, value in cells:
            self[row, column] = value
        self.update()

    def copy(self):
        """Return a new matrix with the same dimensions and stored cells."""
        m = SparseMatrix(self.row, self.column)
        for r, c, v in self:
            m[r, c] = v
        return m

    def getrow(self, row):
        """Return the stored cells of *row* as ``(column, value)``, by column."""
        cursor = self.db.execute(
            "select y,value from matrix where x=? order by y", (row,)
        )
        return cursor.fetchall()

    def getcolumn(self, column):
        """Return the stored cells of *column* as ``(row, value)``, by row."""
        cursor = self.db.execute(
            "select x,value from matrix where y=? order by x", (column,)
        )
        return cursor.fetchall()

    def getallrows(self):
        """Return a tuple of the row indices that hold a stored cell, ascending."""
        rows = self.db.execute(
            "select distinct x from matrix order by x"
        ).fetchall()
        # Built without zip(*rows)[0] so an empty matrix gives an empty tuple.
        return tuple(r[0] for r in rows)

    def getallcolumns(self):
        """Return a tuple of the column indices that hold a stored cell, ascending."""
        columns = self.db.execute(
            "select distinct y from matrix order by y"
        ).fetchall()
        return tuple(c[0] for c in columns)

    def update(self):
        """Delete cells whose value is within 1e-7 of zero, then commit."""
        self.db.execute(
            "DELETE FROM matrix where value between -0.0000001 and 0.0000001"
        )
        self.db.commit()


# Supports subscript access, matrix addition, subtraction and multiplication,
# and iteration. For example:
#
#     m1 = SparseMatrix()
#     m2 = SparseMatrix()
#     m1[0, 0] = 1
#     m2[1, 1] = 2
#     m = m1 + m2
#     for row, column, value in m:
#         print(row, column, value)
#
# This only gets the functionality working; speed is a real headache.
# Multiplying an NxN matrix by an N-dimensional vector took more than
# 6 hours (N=520000).
# There is also a version that supports multi-CPU parallel multiplication,
# but it is too ugly to post here.
# One advantage, though, is that the size of the matrix does not have to be
# known explicitly, which I think can be useful in some situations.
#
# Notice: copyright of ITeye articles belongs to the author and is protected
# by law. Reproduction without the author's written permission is prohibited.
推荐链接
|
|
返回顶楼 | |
发表时间:2009-07-29
为啥要把数据保存到sqlite?
|
|
返回顶楼 | |
发表时间:2009-07-30
主要是为了做乘法的时候比较方便。
其他的不管怎么保存都需要反复遍历才能完成乘法。 而且我实在没有想出用Python实现十字链表的方法。 |
|
返回顶楼 | |
发表时间:2009-08-19
最后修改:2009-08-19
Python 不是有个 SciPy 么?怎么还要劳烦 SQLite 呢?
|
|
返回顶楼 | |
发表时间:2009-09-17
如果有注释就更好了,这个适用于哪些方面?
|
|
返回顶楼 | |