You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tcp_comm.c 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. /**
  2. * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
  3. *
  4. * Parts based on the Pico W tcp_server example:
  5. * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
  6. *
  7. * SPDX-License-Identifier: BSD-3-Clause
  8. */
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "pico/cyw43_arch.h"
#include "lwip/pbuf.h"
#include "lwip/tcp.h"

#include "tcp_comm.h"
  /* Diagnostic output for this module; redefine to silence or redirect it. */
#define DEBUG_printf printf

/*
 * Base for the client poll interval: tcp_comm_client_init() passes
 * POLL_TIME_S * 2 as the lwIP tcp_poll() interval.
 */
#define POLL_TIME_S 5

/* Maximum number of 32-bit argument words in a request or a response. */
#define COMM_MAX_NARG 5
  /*
 * States of the per-connection protocol state machine. The READ_* states
 * (and WAIT_FOR_SYNC) accumulate request bytes; the WRITE_* states wait for
 * a queued reply to be reported as sent by the lwIP sent callback.
 */
enum conn_state {
	CONN_STATE_WAIT_FOR_SYNC = 0,	/* expecting the 4-byte sync word */
	CONN_STATE_READ_OPCODE,		/* expecting a 4-byte opcode */
	CONN_STATE_READ_ARGS,		/* expecting nargs 32-bit argument words */
	CONN_STATE_READ_DATA,		/* expecting the request payload */
	CONN_STATE_HANDLE,		/* NOTE(review): not assigned anywhere in this file */
	CONN_STATE_WRITE_RESP,		/* normal reply queued, waiting for it to be sent */
	CONN_STATE_WRITE_ERROR,		/* error word queued, connection dropped after */
	CONN_STATE_CLOSED,		/* no live client connection */
};
  27. struct tcp_comm_ctx {
  28. struct tcp_pcb *serv_pcb;
  29. volatile bool serv_done;
  30. enum conn_state conn_state;
  31. struct tcp_pcb *client_pcb;
  32. // Note: sizeof(buf) is used elsewhere, so if this is changed to not
  33. // be an array, those will need updating
  34. uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + TCP_COMM_MAX_DATA_LEN];
  35. uint16_t rx_start_offs;
  36. uint16_t rx_bytes_received;
  37. uint16_t rx_bytes_needed;
  38. uint16_t tx_bytes_sent;
  39. uint16_t tx_bytes_remaining;
  40. uint32_t resp_data_len;
  41. const struct comm_command *cmd;
  42. const struct comm_command *const *cmds;
  43. unsigned int n_cmds;
  44. uint32_t sync_opcode;
  45. };
  /*
 * Typed views into a request/response buffer laid out as
 *   [u32 opcode or status][u32 args x nargs][payload bytes]
 */
/* Word 0: opcode on the way in, status on the way out. */
#define COMM_BUF_OPCODE(_buf) ((uint32_t *)(void *)((uint8_t *)(_buf) + 0))
/* Word 1 onward: argument words. */
#define COMM_BUF_ARGS(_buf) ((uint32_t *)(void *)((uint8_t *)(_buf) + sizeof(uint32_t)))
/* First payload byte after the opcode and _nargs argument words. */
#define COMM_BUF_BODY(_buf, _nargs) (&((uint8_t *)(_buf))[sizeof(uint32_t) * (1 + (_nargs))])
  49. static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
  50. {
  51. unsigned int i;
  52. for (i = 0; i < ctx->n_cmds; i++) {
  53. if (ctx->cmds[i]->opcode == opcode) {
  54. return ctx->cmds[i];
  55. }
  56. }
  57. return NULL;
  58. }
  59. static bool is_error(uint32_t status)
  60. {
  61. return status == TCP_COMM_RSP_ERR;
  62. }
  /*
 * Forward declarations for the connection state machine. Each *_begin
 * enters a state; each *_complete runs once that state's bytes have been
 * received (or, for the reply, reported sent). They return 0 to continue
 * and non-zero when the caller should drop the connection.
 */

/* Request side: sync word, opcode, arguments, payload. */
static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx);
static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx);
static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx);
static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx);
static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx);
static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx);
static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx);

/* Reply side: normal response or error word. */
static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx);
static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx);
static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx);
  74. static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx)
  75. {
  76. ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
  77. ctx->rx_bytes_needed = sizeof(uint32_t);
  78. }
  79. static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx)
  80. {
  81. if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
  82. DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  83. return tcp_comm_error_begin(ctx);
  84. }
  85. return tcp_comm_opcode_complete(ctx);
  86. }
  87. static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx)
  88. {
  89. ctx->conn_state = CONN_STATE_READ_OPCODE;
  90. ctx->rx_bytes_needed = sizeof(uint32_t);
  91. return 0;
  92. }
  93. static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx)
  94. {
  95. ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
  96. if (!ctx->cmd) {
  97. DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  98. return tcp_comm_error_begin(ctx);
  99. } else {
  100. DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  101. }
  102. return tcp_comm_args_begin(ctx);
  103. }
  104. static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx)
  105. {
  106. ctx->conn_state = CONN_STATE_READ_ARGS;
  107. ctx->rx_bytes_needed = ctx->cmd->nargs * sizeof(uint32_t);
  108. if (ctx->cmd->nargs == 0) {
  109. return tcp_comm_args_complete(ctx);
  110. }
  111. return 0;
  112. }
  113. static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx)
  114. {
  115. const struct comm_command *cmd = ctx->cmd;
  116. uint32_t data_len = 0;
  117. if (cmd->size) {
  118. uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
  119. &data_len,
  120. &ctx->resp_data_len);
  121. if (is_error(status)) {
  122. return tcp_comm_error_begin(ctx);
  123. }
  124. }
  125. return tcp_comm_data_begin(ctx, data_len);
  126. }
  127. static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
  128. {
  129. const struct comm_command *cmd = ctx->cmd;
  130. ctx->conn_state = CONN_STATE_READ_DATA;
  131. ctx->rx_bytes_needed = data_len;
  132. if (data_len == 0) {
  133. return tcp_comm_data_complete(ctx);
  134. }
  135. return 0;
  136. }
  137. static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx)
  138. {
  139. const struct comm_command *cmd = ctx->cmd;
  140. if (cmd->handle) {
  141. uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
  142. COMM_BUF_BODY(ctx->buf, cmd->nargs),
  143. COMM_BUF_ARGS(ctx->buf),
  144. COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
  145. if (is_error(status)) {
  146. return tcp_comm_error_begin(ctx);
  147. }
  148. *COMM_BUF_OPCODE(ctx->buf) = status;
  149. } else {
  150. // TODO: Should we just assert(desc->handle)?
  151. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_OK;
  152. }
  153. return tcp_comm_response_begin(ctx);
  154. }
  155. static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx)
  156. {
  157. ctx->conn_state = CONN_STATE_WRITE_RESP;
  158. ctx->tx_bytes_sent = 0;
  159. ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
  160. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  161. if (err != ERR_OK) {
  162. return -1;
  163. }
  164. return 0;
  165. }
  166. static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx)
  167. {
  168. ctx->conn_state = CONN_STATE_WRITE_ERROR;
  169. ctx->tx_bytes_sent = 0;
  170. ctx->tx_bytes_remaining = sizeof(uint32_t);
  171. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_ERR;
  172. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  173. if (err != ERR_OK) {
  174. return -1;
  175. }
  176. return 0;
  177. }
  /* The reply has been fully sent: wait for the next request's opcode. */
static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
{
	const int rc = tcp_comm_opcode_begin(ctx);

	return rc;
}
  182. static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
  183. {
  184. switch (ctx->conn_state) {
  185. case CONN_STATE_WAIT_FOR_SYNC:
  186. return tcp_comm_sync_complete(ctx);
  187. case CONN_STATE_READ_OPCODE:
  188. return tcp_comm_opcode_complete(ctx);
  189. case CONN_STATE_READ_ARGS:
  190. return tcp_comm_args_complete(ctx);
  191. case CONN_STATE_READ_DATA:
  192. return tcp_comm_data_complete(ctx);
  193. default:
  194. return -1;
  195. }
  196. }
  197. static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
  198. {
  199. switch (ctx->conn_state) {
  200. case CONN_STATE_WRITE_RESP:
  201. return tcp_comm_response_complete(ctx);
  202. case CONN_STATE_WRITE_ERROR:
  203. return -1;
  204. default:
  205. return -1;
  206. }
  207. }
  208. static err_t tcp_comm_client_close(struct tcp_comm_ctx *ctx)
  209. {
  210. err_t err = ERR_OK;
  211. cyw43_arch_gpio_put (0, false);
  212. ctx->conn_state = CONN_STATE_CLOSED;
  213. if (!ctx->client_pcb) {
  214. return err;
  215. }
  216. tcp_arg(ctx->client_pcb, NULL);
  217. tcp_poll(ctx->client_pcb, NULL, 0);
  218. tcp_sent(ctx->client_pcb, NULL);
  219. tcp_recv(ctx->client_pcb, NULL);
  220. tcp_err(ctx->client_pcb, NULL);
  221. err = tcp_close(ctx->client_pcb);
  222. if (err != ERR_OK) {
  223. DEBUG_printf("close failed %d, calling abort\n", err);
  224. tcp_abort(ctx->client_pcb);
  225. err = ERR_ABRT;
  226. }
  227. ctx->client_pcb = NULL;
  228. return err;
  229. }
  230. err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx)
  231. {
  232. err_t err = ERR_OK;
  233. err = tcp_comm_client_close(ctx);
  234. if ((err != ERR_OK) && ctx->serv_pcb) {
  235. tcp_arg(ctx->serv_pcb, NULL);
  236. tcp_abort(ctx->serv_pcb);
  237. ctx->serv_pcb = NULL;
  238. return ERR_ABRT;
  239. }
  240. if (!ctx->serv_pcb) {
  241. return err;
  242. }
  243. tcp_arg(ctx->serv_pcb, NULL);
  244. err = tcp_close(ctx->serv_pcb);
  245. if (err != ERR_OK) {
  246. tcp_abort(ctx->serv_pcb);
  247. err = ERR_ABRT;
  248. }
  249. ctx->serv_pcb = NULL;
  250. return err;
  251. }
  252. static void tcp_comm_server_complete(void *arg, int status)
  253. {
  254. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  255. if (status == 0) {
  256. DEBUG_printf("server completed normally\n");
  257. } else {
  258. DEBUG_printf("server error %d\n", status);
  259. }
  260. tcp_comm_server_close(ctx);
  261. ctx->serv_done = true;
  262. }
  263. static err_t tcp_comm_client_complete(void *arg, int status)
  264. {
  265. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  266. if (status == 0) {
  267. DEBUG_printf("conn completed normally\n");
  268. } else {
  269. DEBUG_printf("conn error %d\n", status);
  270. }
  271. return tcp_comm_client_close(ctx);
  272. }
  273. static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
  274. {
  275. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  276. DEBUG_printf("tcp_comm_server_sent %u\n", len);
  277. cyw43_arch_lwip_check();
  278. if (len > ctx->tx_bytes_remaining) {
  279. DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
  280. return tcp_comm_client_complete(ctx, ERR_ARG);
  281. }
  282. ctx->tx_bytes_remaining -= len;
  283. ctx->tx_bytes_sent += len;
  284. if (ctx->tx_bytes_remaining == 0) {
  285. int res = tcp_comm_tx_complete(ctx);
  286. if (res) {
  287. return tcp_comm_client_complete(ctx, ERR_ARG);
  288. }
  289. }
  290. return ERR_OK;
  291. }
  292. static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
  293. {
  294. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  295. if (!p) {
  296. DEBUG_printf("no pbuf\n");
  297. return tcp_comm_client_complete(ctx, 0);
  298. }
  299. // this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
  300. // can use this method to cause an assertion in debug mode, if this method is called when
  301. // cyw43_arch_lwip_begin IS needed
  302. cyw43_arch_lwip_check();
  303. if (p->tot_len > 0) {
  304. DEBUG_printf("tcp_comm_server_recv %d err %d\n", p->tot_len, err);
  305. if (p->tot_len > (sizeof(ctx->buf) - ctx->rx_bytes_received)) {
  306. // Doesn't fit in buffer at all. Protocol error.
  307. DEBUG_printf("not enough space in buffer: %d vs %d\n", p->tot_len, sizeof(ctx->buf) - ctx->rx_bytes_received);
  308. // TODO: Invoking the error response state here feels
  309. // like a bit of a layering violation, but this is a
  310. // protocol error, rather than a failure in the stack
  311. // somewhere, so it's nice to try and report it rather
  312. // than just dropping the connection.
  313. if (tcp_comm_error_begin(ctx)) {
  314. return tcp_comm_client_complete(ctx, ERR_ARG);
  315. }
  316. return ERR_OK;
  317. } else if (p->tot_len > (sizeof(ctx->buf) - (ctx->rx_start_offs + ctx->rx_bytes_received))) {
  318. // There will be space, but we need to shift the data back
  319. // to the start of the buffer
  320. DEBUG_printf("memmove %d bytes to make space for %d bytes\n", ctx->rx_bytes_received, p->tot_len);
  321. memmove(ctx->buf, ctx->buf + ctx->rx_start_offs, ctx->rx_bytes_received);
  322. ctx->rx_start_offs = 0;
  323. }
  324. uint8_t *dst = ctx->buf + ctx->rx_start_offs + ctx->rx_bytes_received;
  325. // We can always handle the full packet
  326. if (pbuf_copy_partial(p, dst, p->tot_len, 0) != p->tot_len) {
  327. DEBUG_printf("wrong copy len\n");
  328. return tcp_comm_client_complete(ctx, ERR_ARG);
  329. }
  330. ctx->rx_bytes_received += p->tot_len;
  331. tcp_recved(tpcb, p->tot_len);
  332. while (ctx->rx_bytes_received >= ctx->rx_bytes_needed) {
  333. uint16_t consumed = ctx->rx_bytes_needed;
  334. int res = tcp_comm_rx_complete(ctx);
  335. if (res) {
  336. return tcp_comm_client_complete(ctx, ERR_ARG);
  337. }
  338. ctx->rx_start_offs += consumed;
  339. ctx->rx_bytes_received -= consumed;
  340. if (ctx->rx_bytes_received == 0) {
  341. ctx->rx_start_offs = 0;
  342. break;
  343. }
  344. }
  345. }
  346. pbuf_free(p);
  347. return ERR_OK;
  348. }
  349. static err_t tcp_comm_client_poll(void *arg, struct tcp_pcb *tpcb)
  350. {
  351. DEBUG_printf("tcp_comm_server_poll_fn\n");
  352. return ERR_OK;
  353. }
  354. static void tcp_comm_client_err(void *arg, err_t err)
  355. {
  356. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  357. DEBUG_printf("tcp_comm_err %d\n", err);
  358. ctx->client_pcb = NULL;
  359. ctx->conn_state = CONN_STATE_CLOSED;
  360. ctx->rx_bytes_needed = 0;
  361. cyw43_arch_gpio_put (0, false);
  362. }
  363. static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
  364. {
  365. ctx->client_pcb = pcb;
  366. tcp_arg(pcb, ctx);
  367. cyw43_arch_gpio_put (0, true);
  368. tcp_comm_sync_begin(ctx);
  369. tcp_sent(pcb, tcp_comm_client_sent);
  370. tcp_recv(pcb, tcp_comm_client_recv);
  371. tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
  372. tcp_err(pcb, tcp_comm_client_err);
  373. }
  374. static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
  375. {
  376. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  377. if (err != ERR_OK || client_pcb == NULL) {
  378. DEBUG_printf("Failure in accept\n");
  379. tcp_comm_server_complete(ctx, err);
  380. return ERR_VAL;
  381. }
  382. DEBUG_printf("Connection opened\n");
  383. if (ctx->client_pcb) {
  384. DEBUG_printf("Already have a connection\n");
  385. tcp_abort(client_pcb);
  386. return ERR_ABRT;
  387. }
  388. tcp_comm_client_init(ctx, client_pcb);
  389. return ERR_OK;
  390. }
  391. err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port)
  392. {
  393. DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), port);
  394. ctx->serv_done = false;
  395. struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
  396. if (!pcb) {
  397. DEBUG_printf("failed to create pcb\n");
  398. return ERR_MEM;
  399. }
  400. err_t err = tcp_bind(pcb, NULL, port);
  401. if (err) {
  402. DEBUG_printf("failed to bind to port %d\n", port);
  403. tcp_abort(pcb);
  404. return err;
  405. }
  406. ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
  407. if (!ctx->serv_pcb) {
  408. DEBUG_printf("failed to listen: %d\n", err);
  409. return err;
  410. }
  411. tcp_arg(ctx->serv_pcb, ctx);
  412. tcp_accept(ctx->serv_pcb, tcp_comm_server_accept);
  413. return ERR_OK;
  414. }
  415. struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
  416. unsigned int n_cmds, uint32_t sync_opcode)
  417. {
  418. struct tcp_comm_ctx *ctx = calloc(1, sizeof(struct tcp_comm_ctx));
  419. if (!ctx) {
  420. return NULL;
  421. }
  422. unsigned int i;
  423. for (i = 0; i < n_cmds; i++) {
  424. assert(cmds[i]->nargs <= MAX_NARG);
  425. assert(cmds[i]->resp_nargs <= MAX_NARG);
  426. }
  427. ctx->cmds = cmds;
  428. ctx->n_cmds = n_cmds;
  429. ctx->sync_opcode = sync_opcode;
  430. return ctx;
  431. }
  /*
 * Shut the server down (client and listening pcbs) and free @ctx.
 * Like free(), a NULL @ctx is accepted and ignored; before the guard,
 * tcp_comm_server_close() would dereference it.
 */
void tcp_comm_delete(struct tcp_comm_ctx *ctx)
{
	if (!ctx) {
		return;
	}

	tcp_comm_server_close(ctx);
	free(ctx);
}
  437. bool tcp_comm_server_done(struct tcp_comm_ctx *ctx)
  438. {
  439. return ctx->serv_done;
  440. }