天天看点

十字链表类模板的实现

项目代码链接:https://github.com/weiyx15/SparseMatrix/tree/master/SparseMatrix

第一次用C++写类模板,过程十分艰辛,代码十分冗杂。写几个注意点吧:

1. 重载双目运算符宜用友元(非成员)函数,这样左右两个操作数都能参与隐式类型转换

2. 父类的友元不是子类的友元

3. 子类不会继承父类的operator=:派生类总有自己的复制赋值运算符(未显式声明时由编译器隐式生成,它会调用基类的operator=)

4. 有动态内存分配的类要自己实现拷贝构造函数(深拷贝)、析构函数和重载=运算符

----------------------------------------

代码

OrthoList.h

/*
* 稀疏矩阵类
* 基于十字链表的实现
* 用成员函数实现+, 用友元函数重载*
* 20180401 first edition
*/
#ifndef _ORTHOLIST_H
#define _ORTHOLIST_H
#include<iostream>
#include<vector>
#include "matrix.h"
using namespace std;

// One cell of the orthogonal list. The same type is also used for the row and
// column header nodes, which only use their right/down links.
template <class ValueType>			// ValueType: type of the stored matrix entry
class Node {
public:
	int row;						// row index of the entry
	int col;						// column index of the entry
	ValueType val;					// stored value
	Node<ValueType>* right;			// next node in the same row, or NULL
	Node<ValueType>* down;			// next node in the same column, or NULL
};

// Sparse matrix stored as an orthogonal (cross-linked) list: every non-zero
// entry is a Node linked into its row list (via `right`) and its column list
// (via `down`). Each non-empty row/column starts with a header node.
template <class ListValueType>		// ListValueType: type of the stored matrix entries
class OrthoList {
public:
		OrthoList(vector< vector<ListValueType> > mat);// build from a dense matrix (entries != 0 only)
		OrthoList(int row, int col);// row x col matrix with no stored entries
		OrthoList(void);			// empty 0 x 0 matrix
		OrthoList(const OrthoList<ListValueType> &b);	// deep copy (the class owns heap nodes)
		virtual ~OrthoList();				// virtual because ComplexOrthoList derives from this class

protected:
	Node<ListValueType> **rHead, **cHead;		// per-row / per-column header pointers (NULL for an empty row/column)
	int nRow, nCol, nElement;					// number of rows, columns, stored entries

private:
	bool elementWiseAddition(Node<ListValueType> *aNode);// add a single entry into *this; helper of Addition

public:
	Node<ListValueType> **getRHead() const;	// row header array (the internal array, not a copy)
	Node<ListValueType> **getCHead() const;	// column header array (the internal array, not a copy)
	int getRowNumber() const;				// number of rows
	int getColumnNumber() const;			// number of columns
	int getElementNumber() const;			// number of stored entries
	void printFull() const;					// print as a dense matrix
	void printSparse() const;				// print as (row, col, val) triples
	bool add(int r, int c, ListValueType v);			// insert/overwrite (r,c); false if (r,c) is out of range
	bool del(int r, int c);								// remove (r,c); true if removed, false if not found
	bool Addition(const OrthoList<ListValueType> &b);	// *this += b; false if the sizes differ
	Matrix<ListValueType> toFull() const;				// dense Matrix copy
	OrthoList<ListValueType>& operator= (const OrthoList<ListValueType> &b);	// deep-copy assignment

	// Sparse matrix product a * b. For every (i, j) the row list of a and the
	// column list of b are merged like two index-sorted sequences, so only
	// matching index pairs are multiplied. Entries are only added for (i, j)
	// where at least one pair matched.
	// On a size mismatch an error is printed and an empty 0 x 0 matrix is returned.
	friend OrthoList<ListValueType> operator*
	(const OrthoList<ListValueType> &a, const OrthoList<ListValueType> &b)
	{
		int ra = a.getRowNumber();
		int ca = a.getColumnNumber();
		int rb = b.getRowNumber();
		int cb = b.getColumnNumber();
		if (ca != rb)									// inner dimensions must agree
		{
			cout << "Error: matrices sizes do not match!" << endl;
			OrthoList<ListValueType> nullItem;			// default-constructed (no parentheses)
			return nullItem;
		}
		OrthoList<ListValueType> ans(ra, cb);			// result has the product's dimensions
		Node<ListValueType> ** aRHead = a.getRHead();
		Node<ListValueType> ** bCHead = b.getCHead();
		for (int i=0; i<ra; i++)
		{
			if (aRHead[i] == NULL)						// empty row of a => empty row of the product
			{
				continue;
			}
			for (int j=0; j<cb; j++)
			{
				if (bCHead[j] == NULL)					// empty column of b => nothing to add
				{
					continue;
				}
				ListValueType compute = 0;				// running value of ans[i,j]
				bool has_value = false;					// did any index pair match?
				Node<ListValueType> *p = aRHead[i]->right;
				Node<ListValueType> *q = bCHead[j]->down;
				while (p && q)
				{
					if (p->col < q->row)				// p is behind: advance along the row
					{
						p = p->right;
					}
					else if (p->col > q->row)			// q is behind: advance down the column
					{
						q = q->down;
					}
					else								// p->col == q->row: a[i,k] * b[k,j]
					{
						has_value = true;
						compute += p->val * q->val;
						p = p->right;
						q = q->down;
					}
				}
				if (has_value)
				{
					ans.add(i, j, compute);
				}
			}
		}
		return ans;
	}
};

// Build the orthogonal list from a dense matrix given as a vector of rows.
// Only entries that compare != 0 become nodes. The column count is taken from
// mat[0]; an empty `mat` yields a 0 x 0 matrix instead of throwing. A later
// row shorter than mat[0] still makes at() throw std::out_of_range.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(vector< vector<ListValueType> > mat)
	: rHead(NULL), cHead(NULL),
	  nRow(static_cast<int>(mat.size())),
	  nCol(mat.empty() ? 0 : static_cast<int>(mat[0].size())),
	  nElement(0)
{
	rHead = new Node<ListValueType>*[nRow];
	cHead = new Node<ListValueType>*[nCol];
	for (int i=0; i<nRow; i++)
	{
		rHead[i] = NULL;
	}
	for (int j=0; j<nCol; j++)
	{
		cHead[j] = NULL;
	}

	// Last node of every column list. Rows are scanned top to bottom and columns
	// left to right, so appending at the tails keeps all lists index-sorted and
	// avoids re-walking each list for every insertion.
	vector< Node<ListValueType>* > colTail(nCol);
	for (int i=0; i<nRow; i++)
	{
		Node<ListValueType> *rowTail = NULL;		// last node of row i so far
		for (int j=0; j<nCol; j++)
		{
			if (mat.at(i).at(j) != 0)				// only non-zero entries are stored
			{
				Node<ListValueType>* aNode = new Node<ListValueType>();
				aNode->row = i;
				aNode->col = j;
				aNode->val = mat.at(i).at(j);
				aNode->right = NULL;
				aNode->down = NULL;

				if (rowTail == NULL)				// first entry of the row: create its header
				{
					rHead[i] = new Node<ListValueType>();
					rowTail = rHead[i];
				}
				rowTail->right = aNode;
				rowTail = aNode;

				if (colTail[j] == NULL)				// first entry of the column: create its header
				{
					cHead[j] = new Node<ListValueType>();
					colTail[j] = cHead[j];
				}
				colTail[j]->down = aNode;
				colTail[j] = aNode;

				nElement++;
			}
		}
	}
}

// Build a row x col matrix with no stored entries: both header arrays are
// allocated and every header pointer is value-initialized to NULL.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(int row, int col)
	: rHead(new Node<ListValueType>*[row]()),
	  cHead(new Node<ListValueType>*[col]()),
	  nRow(row), nCol(col), nElement(0)
{
}

// Default constructor: an empty 0 x 0 matrix; no header arrays are allocated.
// Initializers are listed in member declaration order.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(void)
	: rHead(NULL), cHead(NULL), nRow(0), nCol(0), nElement(0)
{
}

// Copy constructor: deep copy of b. Every node of b is cloned; headers are
// created on first use. Because b's rows are visited top to bottom and each row
// left to right, appending at the tail reproduces b's ordering in both the row
// and the column lists.
template <class ListValueType>
OrthoList<ListValueType>::OrthoList(const OrthoList<ListValueType> &b)
	: rHead(NULL), cHead(NULL),
	  nRow(b.getRowNumber()), nCol(b.getColumnNumber()), nElement(b.getElementNumber())
{
	rHead = new Node<ListValueType>*[nRow]();
	cHead = new Node<ListValueType>*[nCol]();
	Node<ListValueType> **srcRows = b.getRHead();
	for (int row = 0; row < nRow; ++row)
	{
		if (srcRows[row] == NULL)
		{
			continue;							// empty row in b
		}
		for (Node<ListValueType> *src = srcRows[row]->right; src != NULL; src = src->right)
		{
			Node<ListValueType> *clone = new Node<ListValueType>();
			clone->row = src->row;
			clone->col = src->col;
			clone->val = src->val;
			clone->right = NULL;
			clone->down = NULL;

			if (rHead[src->row] == NULL)
			{
				rHead[src->row] = new Node<ListValueType>();
			}
			if (cHead[src->col] == NULL)
			{
				cHead[src->col] = new Node<ListValueType>();
			}

			Node<ListValueType> *rowEnd = rHead[src->row];
			while (rowEnd->right != NULL)
			{
				rowEnd = rowEnd->right;
			}
			rowEnd->right = clone;

			Node<ListValueType> *colEnd = cHead[src->col];
			while (colEnd->down != NULL)
			{
				colEnd = colEnd->down;
			}
			colEnd->down = clone;
		}
	}
}

// Destructor: frees every data node, every row and column header node, and the
// two header arrays. Each data node sits in exactly one row list, so the row
// pass frees all data nodes (plus the row headers). The column pass frees only
// the column header nodes, which the original code leaked.
template <class ListValueType>
OrthoList <ListValueType>::~OrthoList()
{
	if (rHead != NULL)
	{
		for (int i = 0; i < nRow; i++)
		{
			Node<ListValueType> *p = rHead[i];	// row header, then the row's data nodes
			while (p != NULL)
			{
				Node<ListValueType> *next = p->right;
				delete p;
				p = next;
			}
		}
	}
	if (cHead != NULL)
	{
		for (int j = 0; j < nCol; j++)
		{
			delete cHead[j];					// column header only (deleting NULL is a no-op)
		}
	}
	delete[] rHead;								// arrays came from new[], so delete[] is required
	delete[] cHead;
	rHead = NULL;
	cHead = NULL;
}

// Copy assignment (copy-and-swap). A deep copy of b is built first, so *this is
// untouched if an allocation throws. The old contents are then swapped into
// `tmp` and released by its destructor when it goes out of scope. This fixes
// three problems in the previous version:
//   - the old lists leaked;
//   - self-assignment was not handled;
//   - it tested cHead[j] with j always 0 instead of cHead[bp->col].
template <class ListValueType>
OrthoList<ListValueType>& OrthoList<ListValueType>::operator= (const OrthoList<ListValueType> &b)
{
	if (this == &b)
	{
		return *this;
	}
	OrthoList<ListValueType> tmp(b);

	Node<ListValueType> **heads = rHead;
	rHead = tmp.rHead;
	tmp.rHead = heads;
	heads = cHead;
	cHead = tmp.cHead;
	tmp.cHead = heads;

	int count = nRow;
	nRow = tmp.nRow;
	tmp.nRow = count;
	count = nCol;
	nCol = tmp.nCol;
	tmp.nCol = count;
	count = nElement;
	nElement = tmp.nElement;
	tmp.nElement = count;

	return *this;
}

// Row header array. This is the object's internal array, not a copy, so callers
// can walk the row lists directly.
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getRHead() const
{
	return this->rHead;
}

// Column header array (internal, not a copy).
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getCHead() const
{
	return this->cHead;
}

// Number of rows.
template <class ListValueType>
int OrthoList <ListValueType>::getRowNumber() const
{
	return this->nRow;
}

// Number of columns.
template <class ListValueType>
int OrthoList <ListValueType>::getColumnNumber() const
{
	return this->nCol;
}

// Number of stored entries, as tracked in nElement.
template <class ListValueType>
int OrthoList <ListValueType>::getElementNumber() const
{
	return this->nElement;
}

// Print the matrix in dense form: one line per row, every entry (including
// zeros) followed by a space.
template <class ListValueType>
void OrthoList <ListValueType>::printFull() const
{
	for (int i=0; i<nRow; i++)
	{
		// Dense copy of row i, value-initialized (0 for arithmetic types).
		// It must hold ListValueType, not int. A vector<int> truncated doubles
		// and did not compile for non-integer types such as Complex.
		vector<ListValueType> line(nCol, ListValueType());
		if (rHead[i]!=NULL)
		{
			for (Node<ListValueType> *p = rHead[i]->right; p!=NULL; p = p->right)
			{
				line.at(p->col) = p->val;
			}
		}
		for (int j=0; j<nCol; j++)
		{
			cout << line.at(j) << " ";
		}
		cout << endl;
	}
}

// Print every stored entry as a "( row, col, val )" triple, one per line.
// Rows are visited top to bottom, and each row list from its header onward.
template <class ListValueType>
void OrthoList <ListValueType>::printSparse() const
{
	for (int r = 0; r < nRow; ++r)
	{
		if (rHead[r] == NULL)
		{
			continue;					// empty row: nothing to print
		}
		for (Node<ListValueType> *cell = rHead[r]->right; cell != NULL; cell = cell->right)
		{
			cout << "( " << cell->row << ", " << cell->col << ", " << cell->val << " )" << endl;
		}
	}
}

// Insert value v at (r, c). If (r, c) already holds a node, its value is
// overwritten instead and nElement is unchanged.
// Row lists are kept sorted by column and column lists by row; operator* relies
// on that order when it merges a row with a column. Header nodes are never
// compared by index. The previous version compared against a header's
// row/col (0), so inserting row 0 / column 0 next to a missing header appended
// the node at the wrong end.
// Returns false (and changes nothing) when (r, c) is out of range.
template <class ListValueType>
bool OrthoList <ListValueType>::add(int r, int c, ListValueType v)
{
	if (r>=nRow || c>=nCol || r<0 || c<0)
	{
		return false;
	}

	// Last node of row r with column < c (the header if there is none).
	Node<ListValueType> *rowPrev = rHead[r];
	if (rowPrev != NULL)
	{
		while (rowPrev->right != NULL && rowPrev->right->col < c)
		{
			rowPrev = rowPrev->right;
		}
		if (rowPrev->right != NULL && rowPrev->right->col == c)
		{
			rowPrev->right->val = v;				// entry exists: overwrite in place
			return true;
		}
	}
	else
	{
		rHead[r] = new Node<ListValueType>();		// first entry of row r: create its header
		rowPrev = rHead[r];
	}

	Node<ListValueType> *aNode = new Node<ListValueType>();
	aNode->row = r;
	aNode->col = c;
	aNode->val = v;
	aNode->right = rowPrev->right;					// splice into the row list
	rowPrev->right = aNode;

	if (cHead[c] == NULL)
	{
		cHead[c] = new Node<ListValueType>();		// first entry of column c: create its header
	}
	Node<ListValueType> *colPrev = cHead[c];
	while (colPrev->down != NULL && colPrev->down->row < r)
	{
		colPrev = colPrev->down;
	}
	aNode->down = colPrev->down;					// splice into the column list
	colPrev->down = aNode;

	nElement++;										// counted on every path that creates a node
	return true;
}

// Add the single entry aNode = (row, col, val) into *this.
//   - If (row, col) already exists, val is added to it and aNode is deleted.
//   - Otherwise aNode itself is spliced into the row and column lists, keeping
//     both index-sorted, and nElement is incremented.
// This function takes ownership of aNode on every path, including the
// out-of-range rejection, where it is deleted and false is returned.
// Helper of Addition.
template <class ListValueType>
bool OrthoList <ListValueType>::elementWiseAddition(Node<ListValueType> *aNode)
{
	int r = aNode->row;
	int c = aNode->col;
	if (r>=nRow || c>=nCol || r<0 || c<0)
	{
		delete aNode;									// don't leak the rejected node
		return false;
	}

	// Last node of row r with column < c (the header if there is none).
	Node<ListValueType> *rowPrev = rHead[r];
	if (rowPrev != NULL)
	{
		while (rowPrev->right != NULL && rowPrev->right->col < c)
		{
			rowPrev = rowPrev->right;
		}
		if (rowPrev->right != NULL && rowPrev->right->col == c)
		{
			rowPrev->right->val += aNode->val;			// entry exists: accumulate
			delete aNode;
			return true;
		}
	}
	else
	{
		rHead[r] = new Node<ListValueType>();			// first entry of row r
		rowPrev = rHead[r];
	}

	aNode->right = rowPrev->right;						// splice into the row list
	rowPrev->right = aNode;

	if (cHead[c] == NULL)
	{
		cHead[c] = new Node<ListValueType>();			// first entry of column c
	}
	Node<ListValueType> *colPrev = cHead[c];
	while (colPrev->down != NULL && colPrev->down->row < r)
	{
		colPrev = colPrev->down;
	}
	aNode->down = colPrev->down;						// splice into the column list
	colPrev->down = aNode;

	nElement++;
	return true;
}

// Remove the entry at (r, c). A row or column header is freed (and its slot
// reset to NULL) when the removed node was the only entry of that list.
// Returns true if a node was removed. Returns false if (r, c) is out of range
// or holds no entry.
template <class ListValueType>
bool OrthoList <ListValueType>::del(int r, int c)
{
	if (r < 0 || c < 0 || r >= nRow || c >= nCol)
	{
		return false;
	}
	if (rHead[r] == NULL)
	{
		return false;								// row r is empty
	}

	// Predecessor of (r, c) in row r.
	Node<ListValueType> *rowPrev = rHead[r];
	while (rowPrev->right != NULL && rowPrev->right->col != c)
	{
		rowPrev = rowPrev->right;
	}
	Node<ListValueType> *victim = rowPrev->right;
	if (victim == NULL)
	{
		return false;								// no entry in column c of this row
	}

	// Predecessor of (r, c) in column c.
	Node<ListValueType> *colPrev = cHead[c];
	while (colPrev->down->row != r)
	{
		colPrev = colPrev->down;
	}

	// Unlink from the column; drop the header if the column becomes empty.
	if (colPrev == cHead[c] && colPrev->down->down == NULL)
	{
		delete cHead[c];
		cHead[c] = NULL;
	}
	else
	{
		colPrev->down = colPrev->down->down;
	}

	// Unlink from the row; drop the header if the row becomes empty.
	if (rowPrev == rHead[r] && victim->right == NULL)
	{
		delete rHead[r];
		rHead[r] = NULL;
	}
	else
	{
		rowPrev->right = victim->right;
	}
	delete victim;

	nElement--;
	return true;
}

// In-place matrix addition: *this += b. Every entry of b is copied into a fresh
// node and handed to elementWiseAddition.
// Returns false, leaving *this unchanged, when the dimensions differ.
template <class ListValueType>
bool OrthoList<ListValueType>::Addition(const OrthoList<ListValueType> &b)
{
	if (b.getRowNumber() != nRow || b.getColumnNumber() != nCol)
	{
		return false;
	}
	Node<ListValueType> **rowsOfB = b.getRHead();
	for (int row = 0; row < nRow; ++row)
	{
		if (rowsOfB[row] == NULL)
		{
			continue;							// nothing to add from this row
		}
		for (Node<ListValueType> *src = rowsOfB[row]->right; src != NULL; src = src->right)
		{
			Node<ListValueType> *entry = new Node<ListValueType>();
			entry->row = src->row;
			entry->col = src->col;
			entry->val = src->val;
			if (!elementWiseAddition(entry))
			{
				return false;
			}
		}
	}
	return true;
}

// Dense copy of the matrix. Entries that are not stored keep the value the
// Matrix(nRow, nCol) constructor gives them.
template <class ListValueType>
Matrix<ListValueType> OrthoList <ListValueType>::toFull() const
{
	Matrix<ListValueType> dense(nRow, nCol);
	for (int row = 0; row < nRow; ++row)
	{
		if (rHead[row] == NULL)
		{
			continue;
		}
		for (Node<ListValueType> *cell = rHead[row]->right; cell != NULL; cell = cell->right)
		{
			dense(cell->row, cell->col) = cell->val;
		}
	}
	return dense;
}


#endif
           

matrix.h

/*
* 稠密矩阵类
* 基于一维数组的实现
* 用友元函数重载+/-/*
* 20180401 first edition
*/
#ifndef _MATRIX_H
#define _MATRIX_H
#include <iostream>
#include <cassert>
using namespace std;
// Dense m x n matrix stored in one heap array in column-major order:
// entry (i, j) lives at data[i + j*m].
template <class T>
class Matrix {
    public:
        int m, n;                       // number of rows, number of columns
    private:
        T* data;                        // m*n entries, or NULL when m or n is not positive
    public:
        // Constructor: an r x c matrix with every entry value-initialized
        // (0 for arithmetic types, T() for class types).
        // - new T[..]() replaces memset, which is only safe for trivially copyable T
        //   and relied on <cstring>, which this header does not include.
        // - new throws std::bad_alloc rather than returning NULL, so no NULL check.
        Matrix(int r=0, int c=0): m(r), n(c), data(NULL) {
            if (r>0 && c>0)
                data = new T[m*n]();
        }
        // Copy constructor: deep copy.
        Matrix(const Matrix<T> &a): m(a.m), n(a.n), data(NULL) {
            if (m>0 && n>0) {
                data = new T[m*n];
                const T* src = a.getData();
                for (int k = 0; k < m*n; ++k)
                    data[k] = src[k];
            }
        }
        // Destructor.
        ~Matrix() {
            m = n = 0;
            delete[] data;
            data = NULL;
        }
        // Raw pointer to the column-major storage (NULL for an empty matrix).
        T* getData() const
        {
            return data;
        }
        // Entry access (asserts the indices are in range).
        T& operator() (int i, int j) const{
            assert(i>=0 && i<m && j>=0 && j<n);
            return data[i+j*m];
        }
        // Copy assignment.
        // - Self-assignment is a no-op; previously it read memory it had just freed.
        // - Exactly a.m*a.n elements are allocated; previously m*n*sizeof(T).
        // - The new buffer is filled before the old one is released, so *this
        //   stays valid if allocation throws.
        Matrix<T>& operator= (const Matrix<T> &a) {
            if (this == &a)
                return *this;
            T* fresh = NULL;
            if (a.m > 0 && a.n > 0) {
                fresh = new T[a.m*a.n];
                for (int k = 0; k < a.m*a.n; ++k)
                    fresh[k] = a.data[k];
            }
            delete[] data;
            data = fresh;
            m = a.m;
            n = a.n;
            return *this;
        }
        // Print "(m, n):" followed by the rows, tab-separated.
        void display() const {
            std::cout<<"("<<m<<", "<<n<<"):"<<std::endl;
            for (int i=0; i<m; i++) {
                for (int j=0; j<n; j++)
                    std::cout<<data[i+j*m]<<"\t";
                std::cout<<std::endl;
            }
        }
        // Friend (non-member) binary operator +.
        friend Matrix<T> operator+ (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.m==b.m && a.n==b.n);
            Matrix<T> r(a.m, a.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<a.n; j++)
                    r(i, j) = a(i, j) + b(i, j);
            return r;
        }
        // Friend (non-member) binary operator -.
        friend Matrix<T> operator- (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.m==b.m && a.n==b.n);
            Matrix<T> r(a.m, a.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<a.n; j++)
                    r(i, j) = a(i, j) - b(i, j);
            return r;
        }
        // Friend (non-member) binary operator * (matrix product).
        friend Matrix<T> operator* (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.n==b.m);
            Matrix<T> r(a.m, b.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<b.n; j++)
                    for (int k=0; k<a.n; k++)
                        r(i, j) += a(i, k) * b(k, j);
            return r;
        }

};

#endif
           

complex.h

/*
* 复数类
* 用友元函数重载复数的+/-/*运算
* 20180401 first edition
*/
#ifndef _COMPLEX_H
#define _COMPLEX_H
#include <iostream>
using namespace std;
// Complex number with double real and imaginary parts. Arithmetic is provided
// by friend (non-member) operators, so a double converts implicitly on either
// side through the constructor.
class Complex {
public:
	Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) { }
	friend Complex operator+(const Complex &c1, const Complex &c2);
	friend Complex operator-(const Complex &c1, const Complex &c2);
	friend Complex operator*(const Complex &c1, const Complex &c2);
	Complex & operator+=(const Complex &c);
	friend std::ostream & operator<<(std::ostream &out, const Complex &c);
private:
	double real; // real part
	double imag; // imaginary part
};

// These functions are defined in a header, so they must be `inline`. Otherwise
// every translation unit that includes complex.h emits its own external
// definition and the link fails with multiple definitions.
inline Complex operator+(const Complex &c1, const Complex &c2){
	return Complex(c1.real+c2.real, c1.imag+c2.imag);
}
inline Complex operator-(const Complex &c1, const Complex &c2){
	return Complex(c1.real-c2.real, c1.imag-c2.imag);
}
// (a+bi)(c+di) = (ac-bd) + (bc+ad)i
inline Complex operator*(const Complex &c1, const Complex &c2){
	return Complex(c1.real*c2.real-c1.imag*c2.imag, c1.imag*c2.real+c1.real*c2.imag);
}
inline Complex & Complex::operator+=(const Complex &c)
{
	real += c.real;
	imag += c.imag;
	return *this;
}
// Writes "(real, imag)".
inline std::ostream & operator<<(std::ostream &out, const Complex &c){
	out << "(" << c.real << ", " << c.imag << ")";
	return out;
}

#endif
           

ComplexOrthoList.h

/*
* 复数稀疏矩阵类
* OrthoList的派生类
* 增加了从文件读入的构造函数
*/
#ifndef _COMPLEXORTHOLIST_H
#define _COMPLEXORTHOLIST_H
#include "OrthoList.h"
#include "matrix.h"
#include<fstream>
#include "complex.h"

// Complex-valued sparse matrix. Derives from OrthoList<Complex> and adds a
// constructor that reads the matrix from a text file.
class ComplexOrthoList : public OrthoList<Complex>
{
public:
	ComplexOrthoList(ifstream &fin);				// read a matrix from a text file
	ComplexOrthoList(void);							// empty 0 x 0 matrix
	ComplexOrthoList(int row, int col)				// row x col matrix with no stored entries
		: OrthoList<Complex>(row, col) {}
	ComplexOrthoList(const ComplexOrthoList &b);	// deep copy
	ComplexOrthoList & operator =(const ComplexOrthoList &b);	// deep-copy assignment

	// Sparse product a * b, returning a ComplexOrthoList. The base class's friend
	// operator* is not a friend of this class and returns OrthoList<Complex>.
	// The result must be created with the product's dimensions. The previous
	// version used a 0 x 0 object, so add() rejected every entry and the product
	// was always empty.
	// On a size mismatch an error is printed and an empty 0 x 0 matrix is returned.
	friend ComplexOrthoList operator *(const ComplexOrthoList &a, const ComplexOrthoList &b)
	{
		int ra = a.getRowNumber();
		int ca = a.getColumnNumber();
		int rb = b.getRowNumber();
		int cb = b.getColumnNumber();
		if (ca != rb)									// inner dimensions must agree
		{
			cout << "Error: matrices sizes do not match!" << endl;
			ComplexOrthoList nullItem;					// default-constructed (no parentheses)
			return nullItem;
		}
		ComplexOrthoList ans(ra, cb);
		Node<Complex> ** aRHead = a.getRHead();
		Node<Complex> ** bCHead = b.getCHead();
		for (int i=0; i<ra; i++)
		{
			if (aRHead[i] == NULL)						// empty row of a => empty row of the product
			{
				continue;
			}
			for (int j=0; j<cb; j++)
			{
				if (bCHead[j] == NULL)					// empty column of b
				{
					continue;
				}
				Complex compute = 0;					// running value of ans[i,j]
				bool has_value = false;					// did any index pair match?
				Node<Complex> *p = aRHead[i]->right;
				Node<Complex> *q = bCHead[j]->down;
				while (p && q)							// merge two index-sorted lists
				{
					if (p->col < q->row)
					{
						p = p->right;
					}
					else if (p->col > q->row)
					{
						q = q->down;
					}
					else								// a[i,k] * b[k,j]
					{
						has_value = true;
						compute += p->val * q->val;
						p = p->right;
						q = q->down;
					}
				}
				if (has_value)
				{
					ans.add(i, j, compute);
				}
			}
		}
		return ans;
	}
};

// Read a complex sparse matrix from a text stream.
// Format:
//   - first line: "nRow nCol";
//   - then one entry per line: "row col real imag", with 1-based indices.
// Behavior:
//   - Entries may appear in any order; each is spliced into its row and column
//     lists at the sorted position.
//   - A repeated (row, col) overwrites the earlier value.
//   - Out-of-range entries are skipped with a warning instead of writing past
//     the header arrays.
//   - If the size line is missing or negative, the matrix stays empty (0 x 0).
// The base default constructor has already set every member to 0 / NULL.
// `inline` is required because this definition lives in a header.
inline ComplexOrthoList::ComplexOrthoList(ifstream &fin)
{
	if (!(fin >> nRow >> nCol) || nRow < 0 || nCol < 0)
	{
		nRow = 0;
		nCol = 0;
		return;
	}
	nElement = 0;
	rHead = new Node<Complex>*[nRow];
	cHead = new Node<Complex>*[nCol];
	for (int k = 0; k < nRow; k++)
	{
		rHead[k] = NULL;
	}
	for (int k = 0; k < nCol; k++)
	{
		cHead[k] = NULL;
	}

	int i = 0, j = 0;
	double Real = 0, Imag = 0;
	while (fin >> i >> j >> Real >> Imag)
	{
		if (i < 1 || i > nRow || j < 1 || j > nCol)
		{
			cerr << "Warning: entry (" << i << ", " << j << ") is outside the "
				<< nRow << " x " << nCol << " matrix and was skipped" << endl;
			continue;
		}
		i--; j--;							// convert to 0-based indices
		Complex com(Real, Imag);

		// Last node of row i with column < j (the header if none).
		if (rHead[i] == NULL)
		{
			rHead[i] = new Node<Complex>();
		}
		Node<Complex> *p = rHead[i];
		while (p->right != NULL && p->right->col < j)
		{
			p = p->right;
		}
		if (p->right != NULL && p->right->col == j)
		{
			p->right->val = com;			// duplicate coordinates: keep the latest value
			continue;
		}

		Node<Complex>* aNode = new Node<Complex>();
		aNode->row = i;
		aNode->col = j;
		aNode->val = com;
		aNode->right = p->right;			// splice into the row list
		p->right = aNode;

		if (cHead[j] == NULL)
		{
			cHead[j] = new Node<Complex>();
		}
		Node<Complex> *q = cHead[j];
		while (q->down != NULL && q->down->row < i)
		{
			q = q->down;
		}
		aNode->down = q->down;				// splice into the column list
		q->down = aNode;

		nElement++;							// previously never counted
	}
}

// Default constructor: an empty 0 x 0 matrix. `inline` because this definition
// lives in a header; without it, including the header twice breaks the link.
inline ComplexOrthoList::ComplexOrthoList(void)
	: OrthoList<Complex>()
{
}

// Copy constructor: the deep copy is done by the base-class copy constructor.
inline ComplexOrthoList::ComplexOrthoList(const ComplexOrthoList &b)
	: OrthoList<Complex>(b)
{
}

// Copy assignment (copy-and-swap). A deep copy of b is built first, so *this is
// untouched if an allocation throws. The old contents are then swapped into
// `tmp` and released by its destructor. This fixes three problems in the
// previous version:
//   - the old lists leaked;
//   - self-assignment was not handled;
//   - it tested cHead[j] with j always 0 instead of cHead[bp->col].
// `inline` because this definition lives in a header.
inline ComplexOrthoList &ComplexOrthoList::operator =(const ComplexOrthoList &b)
{
	if (this == &b)
	{
		return *this;
	}
	ComplexOrthoList tmp(b);

	Node<Complex> **heads = rHead;
	rHead = tmp.rHead;
	tmp.rHead = heads;
	heads = cHead;
	cHead = tmp.cHead;
	tmp.cHead = heads;

	int count = nRow;
	nRow = tmp.nRow;
	tmp.nRow = count;
	count = nCol;
	nCol = tmp.nCol;
	tmp.nCol = count;
	count = nElement;
	nElement = tmp.nElement;
	tmp.nElement = count;

	return *this;
}

#endif
           

test.cpp

#include "OrthoList.h"
#include "matrix.h"
#include "complex.h"
#include "ComplexOrthoList.h"
#include<ctime>

const int TEST_TIME = 1;							// repetitions averaged in the timing tests

// Functional tests of OrthoList and Matrix, then timing of sparse vs dense
// addition and multiplication on matrices read from Real.txt / Real_2.txt.
int main()
{
	// 3 x 3 test matrix:
	//   0 1 0
	//   2 0 0
	//   0 0 3
	vector< vector<int> > mat;
	vector<int> aLine;
	// row 0: 0,1,0
	aLine.push_back(0);
	aLine.push_back(1);
	aLine.push_back(0);
	mat.push_back(aLine);
	aLine.clear();
	// row 1: 2,0,0
	aLine.push_back(2);
	aLine.push_back(0);
	aLine.push_back(0);
	mat.push_back(aLine);
	aLine.clear();
	// row 2: 0,0,3
	aLine.push_back(0);
	aLine.push_back(0);
	aLine.push_back(3);
	mat.push_back(aLine);

	cout << "**********************" << endl;
	cout << "****OrthoList test****" << endl;
	cout << "**********************" << endl;
	OrthoList<int> testOList(mat);
	testOList.printSparse();
	testOList.printFull();
	cout << "Add elements" << endl;
	testOList.add(0,0,-1);
	testOList.add(0,1,-2);
	testOList.add(0,2,-3);
	testOList.add(1,0,-4);
	testOList.add(1,1,-5);
	testOList.add(1,2,-6);
	testOList.add(2,0,-7);
	testOList.add(2,1,-8);
	testOList.add(2,2,-9);
	testOList.printSparse();
	testOList.printFull();
	cout << "Delete elements" << endl;
	testOList.del(0,0);
	testOList.del(1,1);
	testOList.del(1,0);
	testOList.printSparse();
	testOList.printFull();
	OrthoList<int> oList1(testOList);
	cout << " New sparse matrix: " << endl;
	oList1.printFull();
	cout << " matrix1 + matrix2: " << endl;
	testOList.Addition(oList1);
	testOList.printSparse();
	testOList.printFull();
	cout << "matrix1 * martrix2: " << endl;
	OrthoList<int> oList2(oList1);
	oList2 = testOList * oList1;
	oList2.printFull();

	cout << "**********************" << endl;
	cout << "*OrthoList test done**" << endl;
	cout << "**********************" << endl;
	cout << endl;
	cout << "----------------------" << endl;
	cout << "-----matrix test------" << endl;
	cout << "----------------------" << endl;
	Matrix<int> a = oList1.toFull();
	Matrix<int> b = testOList.toFull();
	cout<<"Matrix a"; a.display(); cout<<endl;
	cout<<"Matrix b"; b.display(); cout<<endl;
	Matrix<int> c;
	c = (a+b);
	cout<<"a+b: "; c.display(); cout<<endl;
	Matrix<int> d;
	d = (a*b);
	cout<<"a*b:"; d.display();
	cout << "----------------------" << endl;
	cout << "---matrix test done---" << endl;
	cout << "----------------------" << endl << endl;

	cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
	cout << "@@@ run time test @@@@" << endl;
	cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
	// Check the files explicitly: assert() disappears when NDEBUG is defined.
	ifstream fin("Real.txt");
	if (!fin)
	{
		cerr << "Error: cannot open Real.txt" << endl;
		return 1;
	}
	ComplexOrthoList sY(fin);					// sY: sparse Y matrix
	fin.close();
	fin.clear();								// reading hit EOF; reset the flags before reuse (open() only does this since C++11)
	fin.open("Real_2.txt");
	if (!fin)
	{
		cerr << "Error: cannot open Real_2.txt" << endl;
		return 1;
	}
	ComplexOrthoList sY1(fin);
	fin.close();
	ComplexOrthoList sY2(sY);
	Matrix<Complex> dY = sY.toFull();			// dY: dense Y matrix
	Matrix<Complex> dY1 = sY1.toFull();
	Matrix<Complex> dY2(dY);

	// clock() counts processor ticks, not milliseconds; convert via CLOCKS_PER_SEC.
	const double MS_PER_TICK = 1000.0 / CLOCKS_PER_SEC;
	int i = 0;
	clock_t t0;

	// Matrix addition. NOTE: Addition accumulates into sY, so the sparse product
	// below uses the updated sY, while the dense product uses the unmodified dY.
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		sY.Addition(sY1);
	}
	double run_sparse_add = (clock() - t0) * MS_PER_TICK / TEST_TIME;
	cout << TEST_TIME << " times SPARSE MATRIX addition: "
		<< run_sparse_add << " ms" << endl;
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		dY2 = dY + dY1;
	}
	double run_dense_add = (clock() - t0) * MS_PER_TICK / TEST_TIME;
	cout << TEST_TIME << " times DENSE MATRIX addition: "
		<< run_dense_add << " ms" << endl;

	// Matrix multiplication.
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		sY2 = sY * sY1;
	}
	double run_sparse_mul = (clock() - t0) * MS_PER_TICK / TEST_TIME;
	cout << TEST_TIME << " times SPARSE MATRIX multiplication: "
		<< run_sparse_mul << " ms" << endl;
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		dY2 = dY * dY1;
	}
	double run_dense_mul = (clock() - t0) * MS_PER_TICK / TEST_TIME;
	cout << TEST_TIME << " times DENSE MATRIX multipication: "
		<< run_dense_mul << " ms" << endl;
	return 0;
}
           

继续阅读