server header validation, handle ping message

main
pantonshire 2 weeks ago
parent 001414baa4
commit 6f909a752b

@ -10,9 +10,11 @@
#define OWD_DATA_LEN(HDR) (offsetof(__typeof__(HDR), _data_end))
enum owd_msg_type {
OWD_MSG_ERR,
OWD_MSG_SYNC_REQUEST,
OWD_MSG_SYNC_RESPONSE,
OWD_MSG_ERR = 0,
OWD_MSG_PING = 1,
OWD_MSG_PONG = 2,
OWD_MSG_SYNC_REQUEST = 3,
OWD_MSG_SYNC_RESPONSE = 4,
};
enum owd_error_code {
@ -30,6 +32,7 @@ struct owdhdr {
uint8_t version;
uint8_t msg_type;
uint16_t length;
uint32_t tid;
};
struct owd_msg_err {
@ -44,9 +47,25 @@ struct owd_msg_sync_response {
OWD_DECL_DATA_END;
};
static inline void owd_hdr_populate(struct owdhdr *hdr, uint8_t msg_type, uint16_t length) {
static inline void owd_hdr_populate(
struct owdhdr *hdr,
uint8_t msg_type,
uint16_t length,
uint32_t tid)
{
hdr->magic = htonl(OWD_MAGIC);
hdr->version = OWD_VERSION;
hdr->msg_type = msg_type;
hdr->length = htons(length);
hdr->tid = htonl(tid);
}
static inline int owd_hdr_validate(struct owdhdr *hdr, ssize_t len) {
if (len < sizeof(struct owdhdr))
return 1;
if (hdr->magic != ntohl(OWD_MAGIC))
return 1;
if (hdr->version != OWD_VERSION)
return 1;
return 0;
}

@ -8,6 +8,7 @@
#include <netinet/in.h>
#include "server.h"
#include "io.h"
#include "util.h"
#include "protocol.h"
@ -19,16 +20,13 @@ struct rx_info {
const void *name;
socklen_t namelen;
struct timeval tv;
uint32_t tid;
uint8_t tv_set : 1;
};
static int setsockopt_int(int fd, int level, int option, int val) {
return setsockopt(fd, level, option, &val, sizeof(val));
}
static int server_init(const struct server_conf *conf, struct server_ctx *ctx) {
struct sockaddr_in6 server_addr = { 0 };
int sock_fd;
FD_SCOPED(sock_fd);
server_addr.sin6_family = AF_INET6;
server_addr.sin6_addr = in6addr_any;
@ -39,17 +37,15 @@ static int server_init(const struct server_conf *conf, struct server_ctx *ctx) {
if (setsockopt_int(sock_fd, SOL_SOCKET, SO_REUSEADDR, 1))
FAIL("failed to set SO_REUSEADDR: %s", strerror(errno));
if (setsockopt_int(sock_fd, SOL_SOCKET, SO_TIMESTAMP, 1))
FAIL("failed to set SO_TIMESTAMP: %s", strerror(errno));
if (setsockopt_int(sock_fd, IPPROTO_IPV6, IPV6_V6ONLY, 0))
FAIL("failed to set IPV6_V6ONLY: %s", strerror(errno));
if (bind(sock_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)))
FAIL("failed to bind: %s", strerror(errno));
ctx->fd = sock_fd;
ctx->fd = FD_RELEASE(&sock_fd);
return 0;
}
@ -79,12 +75,21 @@ static void send_error_msg(
{
struct owd_msg_err msg;
owd_hdr_populate(&msg.hdr, OWD_MSG_ERR, OWD_DATA_LEN(msg));
owd_hdr_populate(&msg.hdr, OWD_MSG_ERR, OWD_DATA_LEN(msg), rx->tid);
msg.err = (uint8_t)err;
if (send_to_client(ctx, rx, &msg, OWD_DATA_LEN(msg)))
LOG_ERR("failed to send error message");
}
static int handle_ping(struct server_ctx *ctx, const struct rx_info *rx) {
struct owdhdr hdr;
owd_hdr_populate(&hdr, OWD_MSG_PONG, sizeof(hdr), rx->tid);
if (send_to_client(ctx, rx, &hdr, sizeof(hdr)))
FAIL("failed to send pong");
return 0;
}
static int handle_sync_request(struct server_ctx *ctx, const struct rx_info *rx) {
if (!rx->tv_set)
FAIL("no rx timestamp available");
@ -99,6 +104,7 @@ static int handle_msg(struct server_ctx *ctx, char *buf, ssize_t buf_len, struct
struct rx_info rx;
rx.tv_set = false;
rx.tid = 0;
rx.name = msg_hdr->msg_name;
rx.namelen = msg_hdr->msg_namelen;
@ -116,18 +122,18 @@ static int handle_msg(struct server_ctx *ctx, char *buf, ssize_t buf_len, struct
}
}
if (rx.tv_set) {
printf(" tv_sec=%lld tv_usec=%lld\n", (long long)rx.tv.tv_sec, (long long)rx.tv.tv_usec);
}
if (buf_len < sizeof(struct owdhdr)) {
hdr = (struct owdhdr *)buf;
if (owd_hdr_validate(hdr, buf_len)) {
send_error_msg(ctx, &rx, OWD_ERR_BAD_HEADER);
FAIL("invalid packet: too short");
FAIL("invalid packet: bad header");
}
hdr = (struct owdhdr *)buf;
rx.tid = ntohl(hdr->tid);
switch (hdr->msg_type) {
case OWD_MSG_PING:
return handle_ping(ctx, &rx);
case OWD_MSG_SYNC_REQUEST:
return handle_sync_request(ctx, &rx);
@ -157,8 +163,6 @@ static int server_loop(struct server_ctx *ctx) {
msg_hdr.msg_control = ctrl_buf;
msg_hdr.msg_controllen = sizeof(ctrl_buf);
printf("namelen_before=%zu\n", (size_t)msg_hdr.msg_namelen);
if ((recv_res = recvmsg(ctx->fd, &msg_hdr, 0)) < 0) {
if ((errno == EINTR) || (errno == EAGAIN) || (errno == EWOULDBLOCK))
continue;
@ -167,8 +171,6 @@ static int server_loop(struct server_ctx *ctx) {
continue;
}
printf("namelen_after=%zu\n", (size_t)msg_hdr.msg_namelen);
handle_msg(ctx, msg_buf, recv_res, &msg_hdr);
}
}

Loading…
Cancel
Save