Skip to content

04.请求和响应协议

说明

原文链接:https://build-your-own.org/redis/04_proto

原文作者:James Smith

译者:Cheng

4.1 单个连接中的多个请求

服务器循环

让我们暂时忽略服务器的并发问题,我们通过循环让服务器在单个连接中处理多个请求。

c
while (true) {
    // accept
    struct sockaddr_in client_addr = {};
    socklen_t addrlen = sizeof(client_addr);
    int connfd = accept(fd, (struct sockaddr *)&client_addr, &addrlen);
    if (connfd < 0) {
        continue;   // error
    }
    // 一次只服务一个客户端连接
    while (true) {
        int32_t err = one_request(connfd);
        if (err) {
            break;
        }
    }
    close(connfd);
}

one_request函数负责读取一个请求并返回一个响应。问题在于,如何知道需要读取多少字节?这是应用协议的主要功能。通常协议分为两层结构:

  1. 高层结构:将字节流拆分为消息
  2. 消息内部结构:即反序列化

简单的二进制协议

第一步将字节流拆分为消息。目前,请求和响应消息都只是字符串。

text
┌─────┬──────┬─────┬──────┬────────
│ len │ msg1 │ len │ msg2 │ more...
└─────┴──────┴─────┴──────┴────────
   4B   ...     4B   ...

每条消息都是由一个4字节的小端整数(用于表示请求长度)和一个可变长度的负载组成。这不是Redis所使用的协议,我们后续会讨论其它协议设计。

4.2 解析协议

检查read/write的返回值

read/write返回读取/写入的字节数,错误时返回-1read在EOF(文件结束或连接终止)时返回0。

c
ssize_t read(int fd, void *buf, size_t count);
ssize_t write(int fd, const void *buf, size_t count);

读取消息时,先读取4字节的整数,再读取负载,这是一个错误的实现。演示如下。

c
// 错误示例!
uint32_t n;
char payload[MAX_PAYLOAD];
rv = read(fd, &n, 4);
if (rv != 4) { /* 错误处理 */ }
rv = read(fd, &payload, n);
if (rv != n) { /* 错误处理 */ }

我们在写入消息时,类似的一个错误实现。

c
// 错误示例!
rv = write(fd, &n, 4);
if (rv != 4) { /* 错误处理 */ }
rv = write(fd, &payload, n);
if (rv != n) { /* 错误处理 */ }

这两种方式都是错误的,因为read/write在正常情况下(无错误、无EOF)可能返回少于请求的字节数。为什么会这样呢?我们一会儿来解释。

一个常见的错误是,想当然地认为一次读取操作必然对应于对端的一次写入操作。这是不可能的,因为字节流没有边界。

read_fullwrite_all

为了完整地向TCP套接字读取/写入n个字节数据,我们必须使用循环。

c
static int32_t read_full(int fd, char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = read(fd, buf, n);
        if (rv <= 0) {
            return -1;  // 错误或者意外EOF
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t write_all(int fd, const char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = write(fd, buf, n);
        if (rv <= 0) {
            return -1;  // 错误
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

无论read返回多少数据,都会累积到缓冲区中。重要的是你有多少数据,而不是单次read返回多少。

解析请求并生成响应

在服务器程序中,使用read_fillwrite_all替代readwrite

c
const size_t k_max_msg = 4096;

static int32_t one_request(int connfd) {
    // 4字节头部
    char rbuf[4 + k_max_msg];
    errno = 0;
    int32_t err = read_full(connfd, rbuf, 4);
    if (err) {
        msg(errno == 0 ? "EOF" : "read() error");
        return err;
    }
    uint32_t len = 0;
    memcpy(&len, rbuf, 4);  // 假设是小端序
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }
    // 请求体
    err = read_full(connfd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }
    // 处理请求
    printf("client says: %.*s\n", len, &rbuf[4]);
    // 使用相同的协议回复
    const char reply[] = "world";
    char wbuf[4 + sizeof(reply)];
    len = (uint32_t)strlen(reply);
    memcpy(wbuf, &len, 4);
    memcpy(&wbuf[4], reply, len);
    return write_all(connfd, wbuf, 4 + len);
}

errno陷阱

errno在系统调用失败时设置为错误码。然而, 如果系统调用成功errno不会被设置为0,而是保留之前的值。因此,上述代码在调用read_full前设置errno=0用以区分EOF的情况。

只有在调用失败时才读取errno。但某些libc函数无法判断是否失败,除非先清除errno

c
errno = 0;
int val = atoi("0");    // 错误时返回0,但0也可能是有效结果
if (errno) { /* 失败 */ }

errno在C语言中是一个糟糕的设计。Linux内核根本不使用它。系统调用实际返回错误的状态码,是libc中的syscall包装器将错误代码放入errno。更合理的方式如下:

c
int32_t read(int fd, void *buf, size_t size, size_t *actually_read);
// 返回错误码,通过指针输出结果。

4.3 客户端和测试

c
static int32_t query(int fd, const char *text) {
    uint32_t len = (uint32_t)strlen(text);
    if (len > k_max_msg) {
        return -1;
    }
    // 发送请求
    char wbuf[4 + k_max_msg];
    memcpy(wbuf, &len, 4);  // 假设为小端序
    memcpy(&wbuf[4], text, len);
    if (int32_t err = write_all(fd, wbuf, 4 + len)) {
        return err;
    }
    // 4 字节头部
    char rbuf[4 + k_max_msg + 1];
    errno = 0;
    int32_t err = read_full(fd, rbuf, 4);
    if (err) {
        msg(errno == 0 ? "EOF" : "read() error");
        return err;
    }
    memcpy(&len, rbuf, 4);  // 假设为小端序
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }
    // 响应体
    err = read_full(fd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }
    // 做一些具体处理
    printf("server says: %.*s\n", len, &rbuf[4]);
    return 0;
}

让我们来发送几个命令来测试服务器

c
int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }
    // 代码省略 ...

    // 发送多个请求
    int32_t err = query(fd, "hello1");
    if (err) {
        goto L_DONE;
    }
    err = query(fd, "hello2");
    if (err) {
        goto L_DONE;
    }
L_DONE:
    close(fd);
    return 0;
}

运行服务器和客户端:

bash
$ ./server
client says: hello1
client says: hello2
EOF

$ ./client
server says: world
server says: world

4.4 理解read/write

TCP套接字和磁盘文件

为什么需要read_full?尽管共享相同的read/write API,但读取磁盘文件和读取套接字之间亦有差异。当读取磁盘文件时返回少于请求的数据,这意味着要么是文件结束符(EOF)或错误。但套接字即使在正常条件下也可能返回较少的数据。这可以通过拉取式IO和推送式IO来解释。

网络上的数据是由远程对等方推送的。远程端不需要在发送数据之前等待read调用。内核会分配一个接收缓冲区来存储接收到的数据。read只是将接收缓冲区中可用的任何内容复制到用户空间缓冲区,因为不知道是否还有更多缓冲中的数据。

来自本地文件的数据是从磁盘拉取的。数据总是被认为是“就绪的”,文件大小是已知的。除非是EOF,否则没有理由返回少于请求的数据。

中断的系统调用

为什么需要write_all?通常,write只是将数据追加到内核端缓冲区,实际的网络传输被推迟到操作系统。缓冲区大小有限,所以当缓冲区满时,调用者必须等待它排空才能复制剩余数据。在等待期间,系统调用可能被信号中断,导致write返回部分写入的数据。

read也可能被信号中断,因为如果缓冲区为空,它就必须等待。在这种情况下,读取了0字节,但返回值是-1errnoEINTR。这不是一个错误。这是留给读者的练习:在read_full中处理这种情况。

4.5 关于更多协议

文本与二进制

语气处理二进制数据,为什么不使用更简单、更友好的东西,比如HTTP和JSON?纯文本看起来简单是因为她是人类可读的。但由于实现的复杂,他们对机器来说并不怎么友好。

一个人可读的协议处理字符串,字符串是可变长度的,所以你需要不断检查事物的长度,这既繁琐又容易出错。而二进制协议避免了不必要的字符串,没有什么比memcpy一个结构体更简单的了。

长度前缀与分隔符

本章遵循一个常见模式:

  • 从固定大小的部分开始
  • 可变长度数据紧随其后,长度由固定大小部分指示

解析这样的协议时,你总是知道要读取多少数据。

另一种模式是使用分隔符来指示可变长度部分的结束。要解析分隔协议,需要持续读取直到找到分隔符。但如果有效载荷包含分隔符怎么办?现在你需要转义序列,这增加了更多复杂性。

案例研究:现实世界的协议

HTTP 头是由\r\n分隔的字符串,每个头是由冒号分隔的键值对。URL 可能包含\r\n,所以请求行中的 URL 必须被转义/编码。你可能忘记头部值中不允许\r\n,这已经导致了一些安全漏洞。

GET /index.html HTTP/1.1
Host: example.com
Foo: bar

如果你把编写 HTTP(协议)当作练习,你搞出来的很可能是个漏洞百出的简化版,因为里面涉及太多琐碎工作了,比如各种编码/转义、检查禁用字符等等。HTTP 堪称网络协议设计的反面教材。

真正的 Redis 协议虽然也是人类可读的,但不像 HTTP 那么疯狂。它同时使用分隔符和长度前缀:字符串以长度前缀开头(长度是十进制数字,用换行符分隔),字符串后面也有换行符——但这只是为了可读性。例如:

$5\r\nhello\r\n

你可以尝试实现一个完整的Redis协议当作挑战,毕竟这需要下不少功夫。但别花太多心思在这上面,因为事件循环的优先级更高一点,而且本章代码也复用不了。

4.6 示例代码

04_client.cpp
cpp
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>


static void msg(const char *msg) {
    fprintf(stderr, "%s\n", msg);
}

static void die(const char *msg) {
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

static int32_t read_full(int fd, char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = read(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error, or unexpected EOF
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t write_all(int fd, const char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = write(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

const size_t k_max_msg = 4096;

static int32_t query(int fd, const char *text) {
    uint32_t len = (uint32_t)strlen(text);
    if (len > k_max_msg) {
        return -1;
    }

    char wbuf[4 + k_max_msg];
    memcpy(wbuf, &len, 4);  // assume little endian
    memcpy(&wbuf[4], text, len);
    if (int32_t err = write_all(fd, wbuf, 4 + len)) {
        return err;
    }

    // 4 bytes header
    char rbuf[4 + k_max_msg + 1];
    errno = 0;
    int32_t err = read_full(fd, rbuf, 4);
    if (err) {
        msg(errno == 0 ? "EOF" : "read() error");
        return err;
    }

    memcpy(&len, rbuf, 4);  // assume little endian
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }

    // reply body
    err = read_full(fd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }

    // do something
    printf("server says: %.*s\n", len, &rbuf[4]);
    return 0;
}

int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }

    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(INADDR_LOOPBACK);  // 127.0.0.1
    int rv = connect(fd, (const struct sockaddr *)&addr, sizeof(addr));
    if (rv) {
        die("connect");
    }

    // multiple requests
    int32_t err = query(fd, "hello1");
    if (err) {
        goto L_DONE;
    }
    err = query(fd, "hello2");
    if (err) {
        goto L_DONE;
    }
    err = query(fd, "hello3");
    if (err) {
        goto L_DONE;
    }

L_DONE:
    close(fd);
    return 0;
}
04_server.cpp
cpp
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>


static void msg(const char *msg) {
    fprintf(stderr, "%s\n", msg);
}

static void die(const char *msg) {
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

const size_t k_max_msg = 4096;

static int32_t read_full(int fd, char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = read(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error, or unexpected EOF
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t write_all(int fd, const char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = write(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t one_request(int connfd) {
    // 4 bytes header
    char rbuf[4 + k_max_msg];
    errno = 0;
    int32_t err = read_full(connfd, rbuf, 4);
    if (err) {
        msg(errno == 0 ? "EOF" : "read() error");
        return err;
    }

    uint32_t len = 0;
    memcpy(&len, rbuf, 4);  // assume little endian
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }

    // request body
    err = read_full(connfd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }

    // do something
    fprintf(stderr, "client says: %.*s\n", len, &rbuf[4]);

    // reply using the same protocol
    const char reply[] = "world";
    char wbuf[4 + sizeof(reply)];
    len = (uint32_t)strlen(reply);
    memcpy(wbuf, &len, 4);
    memcpy(&wbuf[4], reply, len);
    return write_all(connfd, wbuf, 4 + len);
}

int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }

    // this is needed for most server applications
    int val = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));

    // bind
    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(0);    // wildcard address 0.0.0.0
    int rv = bind(fd, (const sockaddr *)&addr, sizeof(addr));
    if (rv) {
        die("bind()");
    }

    // listen
    rv = listen(fd, SOMAXCONN);
    if (rv) {
        die("listen()");
    }

    while (true) {
        // accept
        struct sockaddr_in client_addr = {};
        socklen_t addrlen = sizeof(client_addr);
        int connfd = accept(fd, (struct sockaddr *)&client_addr, &addrlen);
        if (connfd < 0) {
            continue;   // error
        }

        while (true) {
            // here the server only serves one client connection at once
            int32_t err = one_request(connfd);
            if (err) {
                break;
            }
        }
        close(connfd);
    }

    return 0;
}