diff --git a/include/protocol.h b/include/protocol.h index 461ed55..cb7e191 100644 --- a/include/protocol.h +++ b/include/protocol.h @@ -10,9 +10,11 @@ #define OWD_DATA_LEN(HDR) (offsetof(__typeof__(HDR), _data_end)) enum owd_msg_type { - OWD_MSG_ERR, - OWD_MSG_SYNC_REQUEST, - OWD_MSG_SYNC_RESPONSE, + OWD_MSG_ERR = 0, + OWD_MSG_PING = 1, + OWD_MSG_PONG = 2, + OWD_MSG_SYNC_REQUEST = 3, + OWD_MSG_SYNC_RESPONSE = 4, }; enum owd_error_code { @@ -30,6 +32,7 @@ struct owdhdr { uint8_t version; uint8_t msg_type; uint16_t length; + uint32_t tid; }; struct owd_msg_err { @@ -44,9 +47,25 @@ struct owd_msg_sync_response { OWD_DECL_DATA_END; }; -static inline void owd_hdr_populate(struct owdhdr *hdr, uint8_t msg_type, uint16_t length) { +static inline void owd_hdr_populate( + struct owdhdr *hdr, + uint8_t msg_type, + uint16_t length, + uint32_t tid) +{ hdr->magic = htonl(OWD_MAGIC); hdr->version = OWD_VERSION; hdr->msg_type = msg_type; hdr->length = htons(length); + hdr->tid = htonl(tid); +} + +static inline int owd_hdr_validate(struct owdhdr *hdr, ssize_t len) { + if (len < sizeof(struct owdhdr)) + return 1; + if (hdr->magic != ntohl(OWD_MAGIC)) + return 1; + if (hdr->version != OWD_VERSION) + return 1; + return 0; } diff --git a/src/server.c b/src/server.c index b901fff..018528c 100644 --- a/src/server.c +++ b/src/server.c @@ -8,6 +8,7 @@ #include #include "server.h" +#include "io.h" #include "util.h" #include "protocol.h" @@ -19,16 +20,13 @@ struct rx_info { const void *name; socklen_t namelen; struct timeval tv; + uint32_t tid; uint8_t tv_set : 1; }; -static int setsockopt_int(int fd, int level, int option, int val) { - return setsockopt(fd, level, option, &val, sizeof(val)); -} - static int server_init(const struct server_conf *conf, struct server_ctx *ctx) { struct sockaddr_in6 server_addr = { 0 }; - int sock_fd; + FD_SCOPED(sock_fd); server_addr.sin6_family = AF_INET6; server_addr.sin6_addr = in6addr_any; @@ -39,17 +37,15 @@ static int server_init(const struct server_conf *conf, struct server_ctx *ctx) { if (setsockopt_int(sock_fd, SOL_SOCKET, SO_REUSEADDR, 1)) FAIL("failed to set SO_REUSEADDR: %s", strerror(errno)); - if (setsockopt_int(sock_fd, SOL_SOCKET, SO_TIMESTAMP, 1)) FAIL("failed to set SO_TIMESTAMP: %s", strerror(errno)); - if (setsockopt_int(sock_fd, IPPROTO_IPV6, IPV6_V6ONLY, 0)) FAIL("failed to set IPV6_V6ONLY: %s", strerror(errno)); if (bind(sock_fd, (struct sockaddr *)&server_addr, sizeof(server_addr))) FAIL("failed to bind: %s", strerror(errno)); - ctx->fd = sock_fd; + ctx->fd = FD_RELEASE(&sock_fd); return 0; } @@ -79,12 +75,21 @@ static void send_error_msg( { struct owd_msg_err msg; - owd_hdr_populate(&msg.hdr, OWD_MSG_ERR, OWD_DATA_LEN(msg)); + owd_hdr_populate(&msg.hdr, OWD_MSG_ERR, OWD_DATA_LEN(msg), rx->tid); msg.err = (uint8_t)err; if (send_to_client(ctx, rx, &msg, OWD_DATA_LEN(msg))) LOG_ERR("failed to send error message"); } +static int handle_ping(struct server_ctx *ctx, const struct rx_info *rx) { + struct owdhdr hdr; + + owd_hdr_populate(&hdr, OWD_MSG_PONG, sizeof(hdr), rx->tid); + if (send_to_client(ctx, rx, &hdr, sizeof(hdr))) + FAIL("failed to send pong"); + return 0; +} + static int handle_sync_request(struct server_ctx *ctx, const struct rx_info *rx) { if (!rx->tv_set) FAIL("no rx timestamp available"); @@ -99,6 +104,7 @@ static int handle_msg(struct server_ctx *ctx, char *buf, ssize_t buf_len, struct struct rx_info rx; rx.tv_set = false; + rx.tid = 0; rx.name = msg_hdr->msg_name; rx.namelen = msg_hdr->msg_namelen; @@ -116,18 +122,18 @@ static int handle_msg(struct server_ctx *ctx, char *buf, ssize_t buf_len, struct } } - if (rx.tv_set) { - printf(" tv_sec=%lld tv_usec=%lld\n", (long long)rx.tv.tv_sec, (long long)rx.tv.tv_usec); - } - - if (buf_len < sizeof(struct owdhdr)) { + hdr = (struct owdhdr *)buf; + if (owd_hdr_validate(hdr, buf_len)) { send_error_msg(ctx, &rx, OWD_ERR_BAD_HEADER); - FAIL("invalid packet: too short"); + FAIL("invalid packet: bad header"); } - hdr = (struct owdhdr *)buf; + rx.tid = ntohl(hdr->tid); switch (hdr->msg_type) { + case OWD_MSG_PING: + return handle_ping(ctx, &rx); + case OWD_MSG_SYNC_REQUEST: return handle_sync_request(ctx, &rx); @@ -157,8 +163,6 @@ static int server_loop(struct server_ctx *ctx) { msg_hdr.msg_control = ctrl_buf; msg_hdr.msg_controllen = sizeof(ctrl_buf); - printf("namelen_before=%zu\n", (size_t)msg_hdr.msg_namelen); - if ((recv_res = recvmsg(ctx->fd, &msg_hdr, 0)) < 0) { if ((errno == EINTR) || (errno == EAGAIN) || (errno == EWOULDBLOCK)) continue; @@ -167,8 +171,6 @@ static int server_loop(struct server_ctx *ctx) { continue; } - printf("namelen_after=%zu\n", (size_t)msg_hdr.msg_namelen); - handle_msg(ctx, msg_buf, recv_res, &msg_hdr); } }