04.请求和响应协议
4.1 单个连接中的多个请求
服务器循环
让我们暂时忽略服务器的并发问题,我们通过循环让服务器在单个连接中处理多个请求。
while (true) {
// accept
struct sockaddr_in client_addr = {};
socklen_t addrlen = sizeof(client_addr);
int connfd = accept(fd, (struct sockaddr *)&client_addr, &addrlen);
if (connfd < 0) {
continue; // error
}
// 一次只服务一个客户端连接
while (true) {
int32_t err = one_request(connfd);
if (err) {
break;
}
}
close(connfd);
}
one_request
函数负责读取一个请求并返回一个响应。问题在于,如何知道需要读取多少字节?这是应用协议的主要功能。通常协议分为两层结构:
- 高层结构:将字节流拆分为消息
- 消息内部结构:即反序列化
简单的二进制协议
第一步将字节流拆分为消息。目前,请求和响应消息都只是字符串。
┌─────┬──────┬─────┬──────┬────────
│ len │ msg1 │ len │ msg2 │ more...
└─────┴──────┴─────┴──────┴────────
4B ... 4B ...
每条消息都是由一个4字节的小端整数(用于表示请求长度)和一个可变长度的负载组成。这不是Redis所使用的协议,我们后续会讨论其它协议设计。
4.2 解析协议
检查read/write的返回值
read/write
返回读取/写入的字节数,错误时返回-1
。read
在EOF(文件结束或连接终止)时返回0。
ssize_t read(int fd, void *buf, size_t count);
ssize_t write(int fd, const void *buf, size_t count);
读取消息时,先读取4字节的整数,再读取负载,这是一个错误的实现。演示如下。
// 错误示例!
uint32_t n;
char payload[MAX_PAYLOAD];
rv = read(fd, &n, 4);
if (rv != 4) { /* 错误处理 */ }
rv = read(fd, &payload, n);
if (rv != n) { /* 错误处理 */ }
我们在写入消息时,类似的一个错误实现。
// 错误示例!
rv = write(fd, &n, 4);
if (rv != 4) { /* 错误处理 */ }
rv = write(fd, &payload, n);
if (rv != n) { /* 错误处理 */ }
这两种方式都是错误的,因为read/write
在正常情况下(无错误、无EOF)可能返回少于请求的字节数。为什么会这样呢?我们一会儿来解释。
一个常见的错误是,想当然地认为一次读取操作必然对应于对端的一次写入操作。这是不可能的,因为字节流没有边界。
read_full
和write_all
为了完整地向TCP套接字读取/写入n
个字节数据,我们必须使用循环。
static int32_t read_full(int fd, char *buf, size_t n) {
while (n > 0) {
ssize_t rv = read(fd, buf, n);
if (rv <= 0) {
return -1; // 错误或者意外EOF
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
static int32_t write_all(int fd, const char *buf, size_t n) {
while (n > 0) {
ssize_t rv = write(fd, buf, n);
if (rv <= 0) {
return -1; // 错误
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
无论read
返回多少数据,都会累积到缓冲区中。重要的是你有多少数据,而不是单次read
返回多少。
解析请求并生成响应
在服务器程序中,使用read_fill
和write_all
替代read
和write
。
const size_t k_max_msg = 4096;
static int32_t one_request(int connfd) {
// 4字节头部
char rbuf[4 + k_max_msg];
errno = 0;
int32_t err = read_full(connfd, rbuf, 4);
if (err) {
msg(errno == 0 ? "EOF" : "read() error");
return err;
}
uint32_t len = 0;
memcpy(&len, rbuf, 4); // 假设是小端序
if (len > k_max_msg) {
msg("too long");
return -1;
}
// 请求体
err = read_full(connfd, &rbuf[4], len);
if (err) {
msg("read() error");
return err;
}
// 处理请求
printf("client says: %.*s\n", len, &rbuf[4]);
// 使用相同的协议回复
const char reply[] = "world";
char wbuf[4 + sizeof(reply)];
len = (uint32_t)strlen(reply);
memcpy(wbuf, &len, 4);
memcpy(&wbuf[4], reply, len);
return write_all(connfd, wbuf, 4 + len);
}
errno陷阱
errno
在系统调用失败时设置为错误码。然而, 如果系统调用成功,errno
就不会被设置为0
,而是保留之前的值。因此,上述代码在调用read_full
前设置errno=0
用以区分EOF的情况。
只有在调用失败时才读取errno
。但某些libc函数无法判断是否失败,除非先清除errno
。
errno = 0;
int val = atoi("0"); // 错误时返回0,但0也可能是有效结果
if (errno) { /* 失败 */ }
errno
在C语言中是一个糟糕的设计。Linux内核根本不使用它。系统调用实际返回错误的状态码,是libc
中的syscall包装器将错误代码放入errno
。更合理的方式如下:
int32_t read(int fd, void *buf, size_t size, size_t *actually_read);
// 返回错误码,通过指针输出结果。
4.3 客户端和测试
static int32_t query(int fd, const char *text) {
uint32_t len = (uint32_t)strlen(text);
if (len > k_max_msg) {
return -1;
}
// 发送请求
char wbuf[4 + k_max_msg];
memcpy(wbuf, &len, 4); // 假设为小端序
memcpy(&wbuf[4], text, len);
if (int32_t err = write_all(fd, wbuf, 4 + len)) {
return err;
}
// 4 字节头部
char rbuf[4 + k_max_msg + 1];
errno = 0;
int32_t err = read_full(fd, rbuf, 4);
if (err) {
msg(errno == 0 ? "EOF" : "read() error");
return err;
}
memcpy(&len, rbuf, 4); // 假设为小端序
if (len > k_max_msg) {
msg("too long");
return -1;
}
// 响应体
err = read_full(fd, &rbuf[4], len);
if (err) {
msg("read() error");
return err;
}
// 做一些具体处理
printf("server says: %.*s\n", len, &rbuf[4]);
return 0;
}
让我们来发送几个命令来测试服务器
int main() {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
die("socket()");
}
// 代码省略 ...
// 发送多个请求
int32_t err = query(fd, "hello1");
if (err) {
goto L_DONE;
}
err = query(fd, "hello2");
if (err) {
goto L_DONE;
}
L_DONE:
close(fd);
return 0;
}
运行服务器和客户端:
$ ./server
client says: hello1
client says: hello2
EOF
$ ./client
server says: world
server says: world
4.4 理解read/write
TCP套接字和磁盘文件
为什么需要read_full
?尽管共享相同的read/write
API,但读取磁盘文件和读取套接字之间亦有差异。当读取磁盘文件时返回少于请求的数据,这意味着要么是文件结束符(EOF)或错误。但套接字即使在正常条件下也可能返回较少的数据。这可以通过拉取式IO
和推送式IO
来解释。
网络上的数据是由远程对等方推送的。远程端不需要在发送数据之前等待read
调用。内核会分配一个接收缓冲区来存储接收到的数据。read
只是将接收缓冲区中可用的任何内容复制到用户空间缓冲区,因为不知道是否还有更多缓冲中的数据。
来自本地文件的数据是从磁盘拉取的。数据总是被认为是“就绪的”,文件大小是已知的。除非是EOF,否则没有理由返回少于请求的数据。
中断的系统调用
为什么需要write_all
?通常,write
只是将数据追加到内核端缓冲区,实际的网络传输被推迟到操作系统。缓冲区大小有限,所以当缓冲区满时,调用者必须等待它排空才能复制剩余数据。在等待期间,系统调用可能被信号中断,导致write
返回部分写入的数据。
read
也可能被信号中断,因为如果缓冲区为空,它就必须等待。在这种情况下,读取了0
字节,但返回值是-1
,errno
是EINTR
。这不是一个错误。这是留给读者的练习:在read_full
中处理这种情况。
4.5 关于更多协议
文本与二进制
语气处理二进制数据,为什么不使用更简单、更友好的东西,比如HTTP和JSON?纯文本看起来简单是因为她是人类可读的。但由于实现的复杂,他们对机器来说并不怎么友好。
一个人可读的协议处理字符串,字符串是可变长度的,所以你需要不断检查事物的长度,这既繁琐又容易出错。而二进制协议避免了不必要的字符串,没有什么比memcpy
一个结构体更简单的了。
长度前缀与分隔符
本章遵循一个常见模式:
- 从固定大小的部分开始
- 可变长度数据紧随其后,长度由固定大小部分指示
解析这样的协议时,你总是知道要读取多少数据。
另一种模式是使用分隔符来指示可变长度部分的结束。要解析分隔协议,需要持续读取直到找到分隔符。但如果有效载荷包含分隔符怎么办?现在你需要转义序列,这增加了更多复杂性。
案例研究:现实世界的协议
HTTP 头是由\r\n
分隔的字符串,每个头是由冒号分隔的键值对。URL 可能包含\r\n
,所以请求行中的 URL 必须被转义/编码。你可能忘记头部值中不允许\r\n
,这已经导致了一些安全漏洞。
GET /index.html HTTP/1.1
Host: example.com
Foo: bar
如果你把编写 HTTP(协议)当作练习,你搞出来的很可能是个漏洞百出的简化版,因为里面涉及太多琐碎工作了,比如各种编码/转义、检查禁用字符等等。HTTP 堪称网络协议设计的反面教材。
真正的 Redis 协议虽然也是人类可读的,但不像 HTTP 那么疯狂。它同时使用分隔符和长度前缀:字符串以长度前缀开头(长度是十进制数字,用换行符分隔),字符串后面也有换行符——但这只是为了可读性。例如:
$5\r\nhello\r\n
你可以尝试实现一个完整的Redis协议当作挑战,毕竟这需要下不少功夫。但别花太多心思在这上面,因为事件循环的优先级更高一点,而且本章代码也复用不了。
4.6 示例代码
04_client.cpp
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>
static void msg(const char *msg) {
fprintf(stderr, "%s\n", msg);
}
static void die(const char *msg) {
int err = errno;
fprintf(stderr, "[%d] %s\n", err, msg);
abort();
}
static int32_t read_full(int fd, char *buf, size_t n) {
while (n > 0) {
ssize_t rv = read(fd, buf, n);
if (rv <= 0) {
return -1; // error, or unexpected EOF
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
static int32_t write_all(int fd, const char *buf, size_t n) {
while (n > 0) {
ssize_t rv = write(fd, buf, n);
if (rv <= 0) {
return -1; // error
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
const size_t k_max_msg = 4096;
static int32_t query(int fd, const char *text) {
uint32_t len = (uint32_t)strlen(text);
if (len > k_max_msg) {
return -1;
}
char wbuf[4 + k_max_msg];
memcpy(wbuf, &len, 4); // assume little endian
memcpy(&wbuf[4], text, len);
if (int32_t err = write_all(fd, wbuf, 4 + len)) {
return err;
}
// 4 bytes header
char rbuf[4 + k_max_msg + 1];
errno = 0;
int32_t err = read_full(fd, rbuf, 4);
if (err) {
msg(errno == 0 ? "EOF" : "read() error");
return err;
}
memcpy(&len, rbuf, 4); // assume little endian
if (len > k_max_msg) {
msg("too long");
return -1;
}
// reply body
err = read_full(fd, &rbuf[4], len);
if (err) {
msg("read() error");
return err;
}
// do something
printf("server says: %.*s\n", len, &rbuf[4]);
return 0;
}
int main() {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
die("socket()");
}
struct sockaddr_in addr = {};
addr.sin_family = AF_INET;
addr.sin_port = ntohs(1234);
addr.sin_addr.s_addr = ntohl(INADDR_LOOPBACK); // 127.0.0.1
int rv = connect(fd, (const struct sockaddr *)&addr, sizeof(addr));
if (rv) {
die("connect");
}
// multiple requests
int32_t err = query(fd, "hello1");
if (err) {
goto L_DONE;
}
err = query(fd, "hello2");
if (err) {
goto L_DONE;
}
err = query(fd, "hello3");
if (err) {
goto L_DONE;
}
L_DONE:
close(fd);
return 0;
}
04_server.cpp
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>
static void msg(const char *msg) {
fprintf(stderr, "%s\n", msg);
}
static void die(const char *msg) {
int err = errno;
fprintf(stderr, "[%d] %s\n", err, msg);
abort();
}
const size_t k_max_msg = 4096;
static int32_t read_full(int fd, char *buf, size_t n) {
while (n > 0) {
ssize_t rv = read(fd, buf, n);
if (rv <= 0) {
return -1; // error, or unexpected EOF
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
static int32_t write_all(int fd, const char *buf, size_t n) {
while (n > 0) {
ssize_t rv = write(fd, buf, n);
if (rv <= 0) {
return -1; // error
}
assert((size_t)rv <= n);
n -= (size_t)rv;
buf += rv;
}
return 0;
}
static int32_t one_request(int connfd) {
// 4 bytes header
char rbuf[4 + k_max_msg];
errno = 0;
int32_t err = read_full(connfd, rbuf, 4);
if (err) {
msg(errno == 0 ? "EOF" : "read() error");
return err;
}
uint32_t len = 0;
memcpy(&len, rbuf, 4); // assume little endian
if (len > k_max_msg) {
msg("too long");
return -1;
}
// request body
err = read_full(connfd, &rbuf[4], len);
if (err) {
msg("read() error");
return err;
}
// do something
fprintf(stderr, "client says: %.*s\n", len, &rbuf[4]);
// reply using the same protocol
const char reply[] = "world";
char wbuf[4 + sizeof(reply)];
len = (uint32_t)strlen(reply);
memcpy(wbuf, &len, 4);
memcpy(&wbuf[4], reply, len);
return write_all(connfd, wbuf, 4 + len);
}
int main() {
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
die("socket()");
}
// this is needed for most server applications
int val = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
// bind
struct sockaddr_in addr = {};
addr.sin_family = AF_INET;
addr.sin_port = ntohs(1234);
addr.sin_addr.s_addr = ntohl(0); // wildcard address 0.0.0.0
int rv = bind(fd, (const sockaddr *)&addr, sizeof(addr));
if (rv) {
die("bind()");
}
// listen
rv = listen(fd, SOMAXCONN);
if (rv) {
die("listen()");
}
while (true) {
// accept
struct sockaddr_in client_addr = {};
socklen_t addrlen = sizeof(client_addr);
int connfd = accept(fd, (struct sockaddr *)&client_addr, &addrlen);
if (connfd < 0) {
continue; // error
}
while (true) {
// here the server only serves one client connection at once
int32_t err = one_request(connfd);
if (err) {
break;
}
}
close(connfd);
}
return 0;
}