天天看点

IOCP windows套接字使用完成端口

本文的代码源自《游戏服务端IOCP模型,自己封装的一个类,3行代码搞定服务端》,我改进过了,希望作者不要说我侵权,我声明这段代码是作者的劳动结晶,我只不过是在此基础上进行了些修改和调试。

windows里有如同Linux中的epoll一般强大的套接字管理功能,即socket编程模型。

我们面对服务器端编程时,往往希望一台主机能同时承接成千上万个客户端连接,只要我们的CPU和内存足够处理业务即可。但对于socket,如果使用select管理,在windows里有最多管理64个套接字的上限,毕竟都是依靠轮询来反馈事件的。如果要管理上百个套接字,我们就需要考虑使用IOCP(完成端口)模型了,见《Windows网络编程》5.2.6 完成端口模型一节的内容。

在经历了2天各种百度学习的情况下,我发现网上对于这个完成端口描述大多都是照本宣科,而且逻辑不完整,同样,书中也有不完整的地方,所以我总结此文,并附带可用的代码供大家参考学习,其中如果有不对的地方,望留言指正!

直接上代码,再说明用法,希望理解完成端口逻辑的同学可以看书或百度:

#pragma once
#include <WinSock2.h>
#include <afxmt.h>
#include <afxtempl.h>

#define ULONG_PTR ULONG
#define PULONG_PTR ULONG*

#define  BUFFER_SIZE 1024
#define  SOCK_TIMEOUT_SECONDS 60

class Iocp;

typedef enum {
	OP_READ   = 1,
	OP_WRITE  = 2,
	OP_ACCEPT = 3,
	OP_CLOSE   = 100,
	OP_DO_WORK = 101
} SOCKET_STATE;

typedef struct
{
	OVERLAPPED oOverlapped;
	WSABUF wsBuffer;
	CHAR szBuffer[BUFFER_SIZE];
	DWORD dSend;
	DWORD dRecv;
	SOCKET_STATE sState;
} PER_IO_DATA, *LPPER_IO_DATA;

/*传送给处理函数的参数*/
typedef struct
{
	SOCKET sSocket; // 客户端socket描述符
	int    index;  // 序号,用于索引
	CHAR   key[32]; // ip:port
	CHAR   szClientIP[24]; // 客户端IP字符串
	UINT   uiClientPort;  // 客户端端口
	time_t lastReceiveTime; // 最后接收时间
	time_t connectedTime; // 创建链接的时间(如果超过这个时间还没有收到有效的ID,那么关闭)
	LPPER_IO_DATA lpIOData; // 释放内存用
	Iocp *pIocp; // ServerScanThread要用
	CMutex *lpMutex;
} IOCPClient, *LPIOCPClient;

typedef struct 
{
	int index; // 同IOCPClient的index
	CMap<CString, LPCTSTR, IOCPClient*, IOCPClient*> sockets;
} STRU_MAP_ClientSockets;

typedef void (*ReadProc)(LPIOCPClient lpData, LPPER_IO_DATA lpPER_IO_DATA);
typedef VOID (*ScanProc)(LPIOCPClient lpClientSocket);

class Iocp
{
public:
	Iocp(const CHAR *host, UINT port);
	~Iocp(void);
	
	VOID SetThreadNums();
	UINT GetThreadNums();
	VOID SetPort(UINT port);
	UINT GetPort();
	BOOL ListenEx(UINT backlog);
	VOID Close();
	VOID Iocp::CreateScanThreads();
	static VOID ServerWorkThread(VOID *_this);
	static VOID ServerScanThread(VOID *s);
	static VOID FreeClientSocket(Iocp *lpIocp, LPIOCPClient lpClientSocket);
	static int Send(SOCKET sockfd, const char *buff, const unsigned int size);
	static VOID SetClientSocketCountText(unsigned int count);
	static VOID OutPutLog(const char *szFormat, ...);
	VOID SetReadFunc(VOID *lprFun);
	VOID SetScanFunc(VOID *lprFun);
	
	int m_ThreadNums; // 线程数量,用于将socket分割到多个区域,扫描时每次只扫描一个区域
	int m_AcceptClientIndex; // 接受连接的socket的序号,跟m_ThreadNums取余
	STRU_MAP_ClientSockets *m_Sockets; // 因为需要根据线程数动态分配内存,所以不能是静态变量
	unsigned int m_SocketCount; // 已连接客户端的数量
	ReadProc m_ReadFun; // 读数据回调函数
	ScanProc m_ScanFun; // 扫描socket回调函数
	HANDLE m_cpHandle;  // IO完成端口句柄
	
	// 扩展的接受连接,放在线程里了
	static VOID AcceptEx(VOID  *_this);
	
	// 监听套接字,即服务端套接字
	SOCKET  m_ListenSocketID;
};

           
#include "stdafx.h"
#include "Iocp.h"
#include <stdlib.h>
#include <process.h>
#include "resource.h"

#pragma comment(lib, "ws2_32.lib")

extern void DoRxTxWork(LPIOCPClient lpClientSocket);

Iocp::Iocp(const CHAR *host, UINT port):
	m_ListenSocketID(INVALID_SOCKET),
	m_AcceptClientIndex(0)
{
	SetClientSocketCountText((m_SocketCount = 0));

	WSADATA wsaData;
	DWORD dwRet = WSAStartup( 0x0202, &wsaData );
	if (0 != dwRet )
	{
		WSACleanup();
		throw 1;
	}

	SOCKADDR_IN sockAddr;
	memset( &sockAddr, 0, sizeof(SOCKADDR_IN) ) ;
	sockAddr.sin_family = AF_INET;
	sockAddr.sin_addr.s_addr = inet_addr(host);
	sockAddr.sin_port = htons(port);

	/*创建监听套接字*/
	m_ListenSocketID = WSASocket( AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED );
	if ( m_ListenSocketID == INVALID_SOCKET )
	{
		throw 1;
	}

	/*设置套接字选项*/
	CHAR opt = 1;
	BOOL ret = setsockopt( m_ListenSocketID , SOL_SOCKET , SO_REUSEADDR , (const CHAR * )&opt , sizeof(opt) );
	if ( ret != 0 )
	{
		throw 1 ;
	}

	/*绑定套接字*/
	if (SOCKET_ERROR == bind(m_ListenSocketID, (const struct sockaddr *)&sockAddr, sizeof(struct sockaddr)))
	{
		throw 1 ;
	}

	/*创建完成端口*/
	m_cpHandle  = CreateIoCompletionPort( INVALID_HANDLE_VALUE, NULL, 0, 0 );
	if ( m_cpHandle == NULL )
	{
		throw 1 ;
	}

	SYSTEM_INFO mySysInfo;
	GetSystemInfo( &mySysInfo );
	m_ThreadNums = (int)mySysInfo.dwNumberOfProcessors * 2;
	//m_ThreadNums = 1;
	m_Sockets = new STRU_MAP_ClientSockets[m_ThreadNums];
	for ( int i = 0; i < m_ThreadNums; i++ )
	{
		m_Sockets[i].index = i;
		_beginthread(Iocp::ServerWorkThread,  0,  (VOID *)this);
	}
	TRACE("工作线程准备完成(%d个)\n", m_ThreadNums);
	OutPutLog("工作线程准备完成(%d个)\n", m_ThreadNums);
}

Iocp::~Iocp(void)
{
	WSACleanup();
}

VOID Iocp::AcceptEx(VOID  *_this)
{
	SOCKET acSocket;
	DWORD dwRecvBytes;
	Iocp * pIocp = (Iocp *)_this;
	SOCKADDR_IN sAddr;
	INT uiClientSize = sizeof(sAddr);

	TRACE("服务器已就绪, 套接字=%u ...\n", pIocp->m_ListenSocketID);
	OutPutLog("服务器已就绪, 套接字=%u ...\n", pIocp->m_ListenSocketID);
	while (TRUE)
	{
		acSocket = WSAAccept( pIocp->m_ListenSocketID, (SOCKADDR *)&sAddr, &uiClientSize, NULL, 0 );
		if ( acSocket == SOCKET_ERROR )
		{
			TRACE("接受连接发生错误: %d\n", WSAGetLastError());
			return;
		}

		LPIOCPClient lpClientSocket = (LPIOCPClient)malloc(sizeof(IOCPClient));
		if ( NULL == lpClientSocket )
		{
			TRACE("Error while malloc lpClientSocket\n");
			return;
		}
		memset(lpClientSocket, 0, sizeof(IOCPClient));

		/*这里停止监听会有问题*/
		LPPER_IO_DATA lpIOData = (LPPER_IO_DATA )malloc(sizeof(PER_IO_DATA));
		if ( lpIOData == NULL )
		{
			TRACE("Error while malloc lpIOData\n");
			return;
		}
		memset(lpIOData, 0, sizeof(PER_IO_DATA));
		
		lpClientSocket->connectedTime = lpClientSocket->lastReceiveTime = time(NULL);
		lpClientSocket->lpIOData = lpIOData; // 释放内存用
		lpClientSocket->sSocket = acSocket;
		lpClientSocket->pIocp = pIocp;
		strcpy(lpClientSocket->szClientIP, inet_ntoa(sAddr.sin_addr));
		lpClientSocket->uiClientPort = sAddr.sin_port;
		_snprintf(lpClientSocket->key, sizeof lpClientSocket->key, "%s:%d", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
		lpClientSocket->lpMutex = new CMutex(FALSE, lpClientSocket->key);
		if (CreateIoCompletionPort( (HANDLE)acSocket, pIocp->m_cpHandle, (ULONG_PTR)lpClientSocket, 0 ) == NULL)
		{
			TRACE("Error while CreateIoCompletionPort\n");
			return;
		}
		TRACE("客户端已连接:%s:%u\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
		OutPutLog("客户端已连接:%s:%u\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
		
		// 投递线程事件
		lpIOData->dSend = 0;
		lpIOData->dRecv = 0;
		lpIOData->wsBuffer.len = BUFFER_SIZE - 1;
		lpIOData->wsBuffer.buf = lpIOData->szBuffer;
		lpIOData->sState = OP_READ;

		DWORD flags = 0;
		if (WSARecv(acSocket, &(lpIOData->wsBuffer), 1, &dwRecvBytes, &flags, &(lpIOData->oOverlapped), NULL) == SOCKET_ERROR)
		{
			if (WSAGetLastError() != ERROR_IO_PENDING )
			{
				TRACE("Error ERROR_IO_PENDING\n");
				return;
			}
			else
			{
				
				// 客户端按接受连接的顺序依次放入4个线程进行扫描处理
				pIocp->m_AcceptClientIndex = (pIocp->m_AcceptClientIndex + 1) % pIocp->m_ThreadNums;
				lpClientSocket->index = pIocp->m_AcceptClientIndex;
				pIocp->m_Sockets[lpClientSocket->index].sockets[lpClientSocket->key] = lpClientSocket;
				SetClientSocketCountText(++pIocp->m_SocketCount);
				TRACE("客户端异步读取已完成,等待读取数据...\n");
				OutPutLog("客户端异步读取已完成,等待读取数据...\n");
			}
		}
	}
}

BOOL Iocp::ListenEx(UINT backlog)
{
	if (SOCKET_ERROR == listen(m_ListenSocketID, backlog))
	{
		return FALSE;
	}
	/*创建监听线程*/
	if (-1 == _beginthread(Iocp::AcceptEx, 0, (VOID *)this))
	{
		return FALSE;
	}
	return TRUE;
}

VOID Iocp:: ServerWorkThread( VOID * _this )
{
	Iocp * lpIocp = (Iocp *)_this;
	HANDLE hPlePort  = (HANDLE)lpIocp->m_cpHandle;
	DWORD dwBytes;
	LPIOCPClient lpClientSocket = NULL;
	LPPER_IO_DATA lpIOData = NULL;
	LPOVERLAPPED lpOverlapped = NULL;
	DWORD sendBytes = 0;
	DWORD recvBytes = 0;
	DWORD dwFlag = 0;
	while (TRUE)
	{
		if (0 == GetQueuedCompletionStatus( hPlePort, &dwBytes, (PULONG_PTR)&lpClientSocket, &lpOverlapped, INFINITE ))
		{
			FreeClientSocket(lpIocp, lpClientSocket);
			continue ;
		}
		lpIOData = (LPPER_IO_DATA)CONTAINING_RECORD(lpOverlapped, PER_IO_DATA, oOverlapped);
		if (0 == dwBytes && (lpIOData->sState == OP_READ || lpIOData->sState == OP_WRITE))
		{
			TRACE("客户端断开了连接:%s\n", lpClientSocket->key);
			OutPutLog("客户端断开了连接:%s\n", lpClientSocket->key);
			closesocket(lpClientSocket->sSocket);
			FreeClientSocket(lpIocp, lpClientSocket);
			continue;
		}

		switch (lpIOData->sState) {
		case OP_READ:
			lpIOData->dRecv = dwBytes;
			lpClientSocket->lastReceiveTime = time(NULL);

			lpIocp->m_ReadFun(lpClientSocket, lpIOData);

			lpIOData->dRecv = 0;
			ZeroMemory( &(lpIOData->oOverlapped), sizeof( OVERLAPPED ) );
			lpIOData->wsBuffer.len = BUFFER_SIZE - 1;
			lpIOData->wsBuffer.buf = lpIOData->szBuffer;
			lpIOData->sState = OP_READ;
			if ( WSARecv( lpClientSocket->sSocket, &(lpIOData->wsBuffer), 1, &recvBytes, &dwFlag, &(lpIOData->oOverlapped), NULL ) == SOCKET_ERROR )
			{
				if ( WSAGetLastError() != ERROR_IO_PENDING )
				{
					return;
				}
			}
			break;
		case OP_WRITE:
			// 什么也不用做
			break;
		case OP_DO_WORK:
			lpIocp->m_ScanFun(lpClientSocket);
			break;
		case OP_CLOSE:
			TRACE("主动断开长期无响应的客户端:%s\n", lpClientSocket->key);
			OutPutLog("主动断开长期无响应的客户端:%s\n", lpClientSocket->key);
			// 这里不能直接释放内存,因为还会触发一次GetQueuedCompletionStatus返回0,在返回0时释放内存
			closesocket(lpClientSocket->sSocket); 
			break;
		default:
			break;
		}

	}
}

VOID Iocp::FreeClientSocket(Iocp *lpIocp, LPIOCPClient lpClientSocket)
{
	if (NULL == lpIocp || NULL == lpClientSocket) {
		return;
	}
	lpIocp->m_Sockets[lpClientSocket->index].sockets.RemoveKey(lpClientSocket->key);
	SetClientSocketCountText(--lpIocp->m_SocketCount);
	free(lpClientSocket->lpIOData);
	free(lpClientSocket);
	TRACE("内存已经释放!\n");
}

VOID Iocp::SetReadFunc(VOID *lprFun)
{
	m_ReadFun  = (ReadProc)lprFun;
}

VOID Iocp::SetScanFunc(VOID *lprFun)
{
	m_ScanFun  = (ScanProc)lprFun;
	CreateScanThreads();
}

VOID Iocp::CreateScanThreads()
{
	STRU_MAP_ClientSockets *sock;
	for (int i = 0; i < m_ThreadNums; i++) {
		sock = &m_Sockets[i];
		_beginthread(Iocp::ServerScanThread,  0,  (VOID *)sock);
	}
}

VOID Iocp::ServerScanThread(VOID *s)
{
	static PER_IO_DATA IOData;
	POSITION pos;
	CString key;
	IOCPClient *lpClientSocket;
	STRU_MAP_ClientSockets *mapSock = (STRU_MAP_ClientSockets*)s;
	int index = mapSock->index;
	int doCount = 0;
	CMap<CString, LPCTSTR, IOCPClient*, IOCPClient*> *serverSockets = &mapSock->sockets;
	while (1) {
		Sleep(5000);
		//OutPutLog("序号[%d]定时器开始处理...", index);
		doCount = 0;
		pos = serverSockets->GetStartPosition();
		while (pos) {
			doCount++;
			serverSockets->GetNextAssoc(pos, key, lpClientSocket);

			memset(&IOData, 0, sizeof(PER_IO_DATA));
			IOData.sState = OP_DO_WORK;
			PostQueuedCompletionStatus(lpClientSocket->pIocp->m_cpHandle, 0, (ULONG_PTR)lpClientSocket, &IOData.oOverlapped);
		}
		//OutPutLog("序号[%d]定时器处理了%d个客户端", index, doCount);
	}
}

void Iocp::SetClientSocketCountText(unsigned int count)
{
	CString countStr;
	countStr.Format("客户端数量: %u", count);

	CWnd *pWnd = AfxGetMainWnd();
	HWND hHwnd = pWnd->m_hWnd;
	::SetDlgItemText(hHwnd, IDC_CLIENT_COUNT, countStr);
}

void Iocp::OutPutLog(const char *szFormat, ...)
{
	static char szLogBuffer[1024];
	
	SYSTEMTIME curTime;
	GetLocalTime(&curTime);
	CString strTime;
	strTime.Format(_T("[%04d-%02d-%02d %02d:%02d:%02d] "),
		curTime.wYear,curTime.wMonth,curTime.wDay,
		curTime.wHour,curTime.wMinute,curTime.wSecond);
	strTime += szFormat;
	
	va_list pArgList;
	va_start(pArgList, szFormat);
	int len = _vsntprintf(szLogBuffer, sizeof szLogBuffer-2, strTime, pArgList);
	va_end(pArgList);
	if (szLogBuffer[len-1] == '\n') {
		if (szLogBuffer[len-2] != '\r') {
			szLogBuffer[len-1] = '\r';
			szLogBuffer[len] = '\n';
			szLogBuffer[len+1] = '\0';
		}
	} else {
		szLogBuffer[len] = '\r';
		szLogBuffer[len+1] = '\n';
		szLogBuffer[len+2] = '\0';
	}
	
	
	CWnd *pWnd =  AfxGetMainWnd();
	CEdit *pEdit = (CEdit*)pWnd->GetDlgItem(IDC_OUTLOG_EDIT);
	if (NULL == pEdit) return;
	
	int iTextLen = pEdit->GetWindowTextLength();
	pEdit->SetRedraw(FALSE);
	pEdit->SetReadOnly(FALSE);
	pEdit->SetSel(iTextLen, iTextLen, TRUE);
	pEdit->ReplaceSel(szLogBuffer);     // 这个函数还是在光标的位置书写
	int lineCount = pEdit->GetLineCount(); // m_prlog是绑定CEDIT控件的对象
    if(lineCount > 100) // 如果输出日志行太多,则删第一行
	{
		pEdit->GetWindowText(szLogBuffer,1024 - 1);//只取前100个字符
		CString tmp(szLogBuffer);
		int it1 = tmp.Find("\r\n") + 2; // 查找第一行的回车换行位置
		pEdit->SetSel(0, it1); // 选择要删除的首行
		pEdit->ReplaceSel(""); // 用空串替换掉首行
	}
	pEdit->LineScroll(lineCount);   //可用于水平滚动所有行最后一个字符,这只是设置edit进行滚动
	pEdit->SetReadOnly(TRUE);
	pEdit->SetRedraw(TRUE);
}

int Iocp::Send(SOCKET sockfd, const char *buff, const unsigned int size)
{
	static PER_IO_DATA PerIOData;

	memset(&PerIOData, 0, sizeof(PER_IO_DATA));
	PerIOData.sState = OP_WRITE;
	PerIOData.wsBuffer.len = size;
	PerIOData.wsBuffer.buf = (char *)buff;
	DWORD byteSend = 0;
	int ErrorCode;
	int result = WSASend(sockfd, &PerIOData.wsBuffer, 1, &byteSend, 0, &PerIOData.oOverlapped, NULL);
	if (SOCKET_ERROR == result && ERROR_IO_PENDING != (ErrorCode = WSAGetLastError())) {
		TRACE("发送数据出错,错误码: %d\n", ErrorCode);
	} else {
		TRACE("成功发送数据: %d字节,返回值:%d\n", byteSend, result);
	}
	return result;
}
           
// 回调1:客户端的发送的数据会在这个函数通知
void OnRead(LPIOCPClient lpClientSocket, LPPER_IO_DATA lpIOData)
{
	if (NULL == lpClientSocket || NULL == lpIOData) {
		return;
	}
    int RxCount = (int) lpIOData->dRecv;
    char *RxBuff  = lpIOData->szBuffer;
	RxBuff[RxCount] = '\0'; // 务必保证接收时留1个字节余量给这个结尾的0
	Iocp::OutPutLog("%s:%d: %s\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort, RxBuff);
	Iocp::Send(lpClientSocket->sSocket, RxBuff, RxCount);
}

// 回调2:扫描套接字,目的是关闭闲置套接字,或定时发送心跳包(业务逻辑上要求对方回答)
// 其实这个函数可以直接关闭套接字,只不过通过单独的CLOSE通知会对业务处理更灵活和方便
// 如果你在业务中体会不到,可以直接调用closesocket即可。
VOID OnScan(LPIOCPClient lpClientSocket)
{
	static PER_IO_DATA IOData;
	if (NULL == lpClientSocket) {
		return;
	}
	if (time(NULL) - lpClientSocket->lastReceiveTime > SOCK_TIMEOUT_SECONDS) {
		memset(&IOData, 0, sizeof(PER_IO_DATA));
		IOData.sState = OP_CLOSE;
		PostQueuedCompletionStatus(lpClientSocket->pIocp->m_cpHandle, 0, (ULONG_PTR)lpClientSocket, &IOData.oOverlapped);
	}
}
// 调用的位置,在MFC项目里
void CIOCPSocketDlg::OnRunServer() 
{
	static Iocp *g_IocpServer = NULL;
	if (NULL == g_IocpServer) {
		g_IocpServer = new Iocp("0.0.0.0",  8888);
		g_IocpServer->SetReadFunc(OnRead); // 回调1,读取套接字发来的内容
		g_IocpServer->SetScanFunc(OnScan); // 回调2,定期扫描套接字,可能是业务逻辑要求发心跳包,这个步骤可以免去
		g_IocpServer->ListenEx(10);
	}
}
           

代码在MFC项目里,所以还有设置窗体内容的逻辑,大家修改成自己的即可。

补充代码逻辑(书上没讲到的):如果要主动关闭套接字,直接调用closesocket函数即可,因为调用此函数会导致GetQueuedCompletionStatus函数返回0,在返回0的逻辑里释放两个malloc的变量即可。而如果是客户端断开了,GetQueuedCompletionStatus返回的不是0,但满足0 == dwBytes并且是读状态,在这里则除了调用closesocket以外,还要释放malloc的变量(书上讲到了)。

另外:对于同一个套接字,应该不会同时在多个线程里出发读取完成操作,但是很可能在多个线程里出发读取和扫描通知(OP_DO_WORK),所以在业务中,如有必要,要考虑给每一个客户端一个mutex,加锁处理。

使用时,只要实现read,scan两个方法,如果确定不要scan(无法处理从网络中消失的客户端,比如客户端突然死机或断网,或服务器断网一段时间,在此期间客户端主动断开了),那么就不挂载scan函数即可,看SetScanFunc的实现里有新建线程的操作哦:

VOID Iocp::SetScanFunc(VOID *lprFun)
{
	m_ScanFun  = (ScanProc)lprFun;
	CreateScanThreads();
}
           

继续阅读