You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tcp_comm.c 13KB

1–554 (line-number gutter from the original page view)
  1. /**
  2. * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
  3. *
  4. * Parts based on the Pico W tcp_server example:
  5. * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
  6. *
  7. * SPDX-License-Identifier: BSD-3-Clause
  8. */
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "pico/cyw43_arch.h"
#include "lwip/pbuf.h"
#include "lwip/tcp.h"
#include "tcp_comm.h"
  14. #ifdef DEBUG
  15. #include <stdio.h>
  16. #define DEBUG_printf(...) printf(__VA_ARGS__)
  17. #else
  18. #define DEBUG_printf(...) { }
  19. #endif
  20. #define POLL_TIME_S 5
  21. #define COMM_MAX_NARG 5
  22. enum conn_state {
  23. CONN_STATE_WAIT_FOR_SYNC,
  24. CONN_STATE_READ_OPCODE,
  25. CONN_STATE_READ_ARGS,
  26. CONN_STATE_READ_DATA,
  27. CONN_STATE_HANDLE,
  28. CONN_STATE_WRITE_RESP,
  29. CONN_STATE_WRITE_ERROR,
  30. CONN_STATE_CLOSED,
  31. };
  32. struct tcp_comm_ctx {
  33. struct tcp_pcb *serv_pcb;
  34. volatile bool serv_done;
  35. enum conn_state conn_state;
  36. struct tcp_pcb *client_pcb;
  37. // Note: sizeof(buf) is used elsewhere, so if this is changed to not
  38. // be an array, those will need updating
  39. uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + TCP_COMM_MAX_DATA_LEN];
  40. uint16_t rx_start_offs;
  41. uint16_t rx_bytes_received;
  42. uint16_t rx_bytes_needed;
  43. uint16_t tx_bytes_sent;
  44. uint16_t tx_bytes_remaining;
  45. uint32_t resp_data_len;
  46. const struct comm_command *cmd;
  47. const struct comm_command *const *cmds;
  48. unsigned int n_cmds;
  49. uint32_t sync_opcode;
  50. };
  51. #define COMM_BUF_OPCODE(_buf) ((uint32_t *)((uint8_t *)(_buf)))
  52. #define COMM_BUF_ARGS(_buf) ((uint32_t *)((uint8_t *)(_buf) + sizeof(uint32_t)))
  53. #define COMM_BUF_BODY(_buf, _nargs) ((uint8_t *)(_buf) + (sizeof(uint32_t) * ((_nargs) + 1)))
  54. static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
  55. {
  56. unsigned int i;
  57. for (i = 0; i < ctx->n_cmds; i++) {
  58. if (ctx->cmds[i]->opcode == opcode) {
  59. return ctx->cmds[i];
  60. }
  61. }
  62. return NULL;
  63. }
  64. static bool is_error(uint32_t status)
  65. {
  66. return status == TCP_COMM_RSP_ERR;
  67. }
  68. static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx);
  69. static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx);
  70. static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx);
  71. static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx);
  72. static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx);
  73. static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx);
  74. static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
  75. static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx);
  76. static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx);
  77. static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx);
  78. static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx);
  79. static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx)
  80. {
  81. ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
  82. ctx->rx_bytes_needed = sizeof(uint32_t);
  83. return 0;
  84. }
  85. static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx)
  86. {
  87. if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
  88. DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  89. return tcp_comm_error_begin(ctx);
  90. }
  91. return tcp_comm_opcode_complete(ctx);
  92. }
  93. static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx)
  94. {
  95. ctx->conn_state = CONN_STATE_READ_OPCODE;
  96. ctx->rx_bytes_needed = sizeof(uint32_t);
  97. return 0;
  98. }
  99. static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx)
  100. {
  101. ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
  102. if (!ctx->cmd) {
  103. DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  104. return tcp_comm_error_begin(ctx);
  105. } else {
  106. DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  107. }
  108. return tcp_comm_args_begin(ctx);
  109. }
  110. static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx)
  111. {
  112. ctx->conn_state = CONN_STATE_READ_ARGS;
  113. ctx->rx_bytes_needed = ctx->cmd->nargs * sizeof(uint32_t);
  114. if (ctx->cmd->nargs == 0) {
  115. return tcp_comm_args_complete(ctx);
  116. }
  117. return 0;
  118. }
  119. static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx)
  120. {
  121. const struct comm_command *cmd = ctx->cmd;
  122. uint32_t data_len = 0;
  123. if (cmd->size) {
  124. uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
  125. &data_len,
  126. &ctx->resp_data_len);
  127. if (is_error(status)) {
  128. return tcp_comm_error_begin(ctx);
  129. }
  130. }
  131. return tcp_comm_data_begin(ctx, data_len);
  132. }
  133. static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
  134. {
  135. ctx->conn_state = CONN_STATE_READ_DATA;
  136. ctx->rx_bytes_needed = data_len;
  137. if (data_len == 0) {
  138. return tcp_comm_data_complete(ctx);
  139. }
  140. return 0;
  141. }
  142. static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx)
  143. {
  144. const struct comm_command *cmd = ctx->cmd;
  145. if (cmd->handle) {
  146. uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
  147. COMM_BUF_BODY(ctx->buf, cmd->nargs),
  148. COMM_BUF_ARGS(ctx->buf),
  149. COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
  150. if (is_error(status)) {
  151. return tcp_comm_error_begin(ctx);
  152. }
  153. *COMM_BUF_OPCODE(ctx->buf) = status;
  154. } else {
  155. // TODO: Should we just assert(desc->handle)?
  156. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_OK;
  157. }
  158. return tcp_comm_response_begin(ctx);
  159. }
  160. static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx)
  161. {
  162. ctx->conn_state = CONN_STATE_WRITE_RESP;
  163. ctx->tx_bytes_sent = 0;
  164. ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
  165. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  166. if (err != ERR_OK) {
  167. return -1;
  168. }
  169. return 0;
  170. }
  171. static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx)
  172. {
  173. ctx->conn_state = CONN_STATE_WRITE_ERROR;
  174. ctx->tx_bytes_sent = 0;
  175. ctx->tx_bytes_remaining = sizeof(uint32_t);
  176. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_ERR;
  177. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  178. if (err != ERR_OK) {
  179. return -1;
  180. }
  181. return 0;
  182. }
  183. static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
  184. {
  185. return tcp_comm_opcode_begin(ctx);
  186. }
  187. static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
  188. {
  189. switch (ctx->conn_state) {
  190. case CONN_STATE_WAIT_FOR_SYNC:
  191. return tcp_comm_sync_complete(ctx);
  192. case CONN_STATE_READ_OPCODE:
  193. return tcp_comm_opcode_complete(ctx);
  194. case CONN_STATE_READ_ARGS:
  195. return tcp_comm_args_complete(ctx);
  196. case CONN_STATE_READ_DATA:
  197. return tcp_comm_data_complete(ctx);
  198. default:
  199. return -1;
  200. }
  201. }
  202. static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
  203. {
  204. switch (ctx->conn_state) {
  205. case CONN_STATE_WRITE_RESP:
  206. return tcp_comm_response_complete(ctx);
  207. case CONN_STATE_WRITE_ERROR:
  208. return -1;
  209. default:
  210. return -1;
  211. }
  212. }
  213. static err_t tcp_comm_client_close(struct tcp_comm_ctx *ctx)
  214. {
  215. err_t err = ERR_OK;
  216. cyw43_arch_gpio_put (0, false);
  217. ctx->conn_state = CONN_STATE_CLOSED;
  218. if (!ctx->client_pcb) {
  219. return err;
  220. }
  221. tcp_arg(ctx->client_pcb, NULL);
  222. tcp_poll(ctx->client_pcb, NULL, 0);
  223. tcp_sent(ctx->client_pcb, NULL);
  224. tcp_recv(ctx->client_pcb, NULL);
  225. tcp_err(ctx->client_pcb, NULL);
  226. err = tcp_close(ctx->client_pcb);
  227. if (err != ERR_OK) {
  228. DEBUG_printf("close failed %d, calling abort\n", err);
  229. tcp_abort(ctx->client_pcb);
  230. err = ERR_ABRT;
  231. }
  232. ctx->client_pcb = NULL;
  233. return err;
  234. }
  235. err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx)
  236. {
  237. err_t err = ERR_OK;
  238. err = tcp_comm_client_close(ctx);
  239. if ((err != ERR_OK) && ctx->serv_pcb) {
  240. tcp_arg(ctx->serv_pcb, NULL);
  241. tcp_abort(ctx->serv_pcb);
  242. ctx->serv_pcb = NULL;
  243. return ERR_ABRT;
  244. }
  245. if (!ctx->serv_pcb) {
  246. return err;
  247. }
  248. tcp_arg(ctx->serv_pcb, NULL);
  249. err = tcp_close(ctx->serv_pcb);
  250. if (err != ERR_OK) {
  251. tcp_abort(ctx->serv_pcb);
  252. err = ERR_ABRT;
  253. }
  254. ctx->serv_pcb = NULL;
  255. return err;
  256. }
  257. static void tcp_comm_server_complete(void *arg, int status)
  258. {
  259. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  260. if (status == 0) {
  261. DEBUG_printf("server completed normally\n");
  262. } else {
  263. DEBUG_printf("server error %d\n", status);
  264. }
  265. tcp_comm_server_close(ctx);
  266. ctx->serv_done = true;
  267. }
  268. static err_t tcp_comm_client_complete(void *arg, int status)
  269. {
  270. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  271. if (status == 0) {
  272. DEBUG_printf("conn completed normally\n");
  273. } else {
  274. DEBUG_printf("conn error %d\n", status);
  275. }
  276. return tcp_comm_client_close(ctx);
  277. }
  278. static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
  279. {
  280. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  281. DEBUG_printf("tcp_comm_server_sent %u\n", len);
  282. cyw43_arch_lwip_check();
  283. if (len > ctx->tx_bytes_remaining) {
  284. DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
  285. return tcp_comm_client_complete(ctx, ERR_ARG);
  286. }
  287. ctx->tx_bytes_remaining -= len;
  288. ctx->tx_bytes_sent += len;
  289. if (ctx->tx_bytes_remaining == 0) {
  290. int res = tcp_comm_tx_complete(ctx);
  291. if (res) {
  292. return tcp_comm_client_complete(ctx, ERR_ARG);
  293. }
  294. }
  295. return ERR_OK;
  296. }
  297. static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
  298. {
  299. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  300. if (!p) {
  301. DEBUG_printf("no pbuf\n");
  302. return tcp_comm_client_complete(ctx, 0);
  303. }
  304. // this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
  305. // can use this method to cause an assertion in debug mode, if this method is called when
  306. // cyw43_arch_lwip_begin IS needed
  307. cyw43_arch_lwip_check();
  308. if (p->tot_len > 0) {
  309. DEBUG_printf("tcp_comm_server_recv %d err %d\n", p->tot_len, err);
  310. if (p->tot_len > (sizeof(ctx->buf) - ctx->rx_bytes_received)) {
  311. // Doesn't fit in buffer at all. Protocol error.
  312. DEBUG_printf("not enough space in buffer: %d vs %d\n", p->tot_len, sizeof(ctx->buf) - ctx->rx_bytes_received);
  313. // TODO: Invoking the error response state here feels
  314. // like a bit of a layering violation, but this is a
  315. // protocol error, rather than a failure in the stack
  316. // somewhere, so it's nice to try and report it rather
  317. // than just dropping the connection.
  318. if (tcp_comm_error_begin(ctx)) {
  319. return tcp_comm_client_complete(ctx, ERR_ARG);
  320. }
  321. return ERR_OK;
  322. } else if (p->tot_len > (sizeof(ctx->buf) - (ctx->rx_start_offs + ctx->rx_bytes_received))) {
  323. // There will be space, but we need to shift the data back
  324. // to the start of the buffer
  325. DEBUG_printf("memmove %d bytes to make space for %d bytes\n", ctx->rx_bytes_received, p->tot_len);
  326. memmove(ctx->buf, ctx->buf + ctx->rx_start_offs, ctx->rx_bytes_received);
  327. ctx->rx_start_offs = 0;
  328. }
  329. uint8_t *dst = ctx->buf + ctx->rx_start_offs + ctx->rx_bytes_received;
  330. // We can always handle the full packet
  331. if (pbuf_copy_partial(p, dst, p->tot_len, 0) != p->tot_len) {
  332. DEBUG_printf("wrong copy len\n");
  333. return tcp_comm_client_complete(ctx, ERR_ARG);
  334. }
  335. ctx->rx_bytes_received += p->tot_len;
  336. tcp_recved(tpcb, p->tot_len);
  337. while (ctx->rx_bytes_received >= ctx->rx_bytes_needed) {
  338. uint16_t consumed = ctx->rx_bytes_needed;
  339. int res = tcp_comm_rx_complete(ctx);
  340. if (res) {
  341. return tcp_comm_client_complete(ctx, ERR_ARG);
  342. }
  343. ctx->rx_start_offs += consumed;
  344. ctx->rx_bytes_received -= consumed;
  345. if (ctx->rx_bytes_received == 0) {
  346. ctx->rx_start_offs = 0;
  347. break;
  348. }
  349. }
  350. }
  351. pbuf_free(p);
  352. return ERR_OK;
  353. }
  354. static err_t tcp_comm_client_poll(void *arg, struct tcp_pcb *tpcb)
  355. {
  356. DEBUG_printf("tcp_comm_server_poll_fn\n");
  357. return ERR_OK;
  358. }
  359. static void tcp_comm_client_err(void *arg, err_t err)
  360. {
  361. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  362. DEBUG_printf("tcp_comm_err %d\n", err);
  363. ctx->client_pcb = NULL;
  364. ctx->conn_state = CONN_STATE_CLOSED;
  365. ctx->rx_bytes_needed = 0;
  366. cyw43_arch_gpio_put (0, false);
  367. }
  368. static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
  369. {
  370. ctx->client_pcb = pcb;
  371. tcp_arg(pcb, ctx);
  372. cyw43_arch_gpio_put (0, true);
  373. tcp_comm_sync_begin(ctx);
  374. tcp_sent(pcb, tcp_comm_client_sent);
  375. tcp_recv(pcb, tcp_comm_client_recv);
  376. tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
  377. tcp_err(pcb, tcp_comm_client_err);
  378. }
  379. static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
  380. {
  381. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  382. if (err != ERR_OK || client_pcb == NULL) {
  383. DEBUG_printf("Failure in accept\n");
  384. tcp_comm_server_complete(ctx, err);
  385. return ERR_VAL;
  386. }
  387. DEBUG_printf("Connection opened\n");
  388. if (ctx->client_pcb) {
  389. DEBUG_printf("Already have a connection\n");
  390. tcp_abort(client_pcb);
  391. return ERR_ABRT;
  392. }
  393. tcp_comm_client_init(ctx, client_pcb);
  394. return ERR_OK;
  395. }
  396. err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port)
  397. {
  398. DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), port);
  399. ctx->serv_done = false;
  400. struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
  401. if (!pcb) {
  402. DEBUG_printf("failed to create pcb\n");
  403. return ERR_MEM;
  404. }
  405. err_t err = tcp_bind(pcb, NULL, port);
  406. if (err) {
  407. DEBUG_printf("failed to bind to port %d\n", port);
  408. tcp_abort(pcb);
  409. return err;
  410. }
  411. ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
  412. if (!ctx->serv_pcb) {
  413. DEBUG_printf("failed to listen: %d\n", err);
  414. return err;
  415. }
  416. tcp_arg(ctx->serv_pcb, ctx);
  417. tcp_accept(ctx->serv_pcb, tcp_comm_server_accept);
  418. return ERR_OK;
  419. }
  420. struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
  421. unsigned int n_cmds, uint32_t sync_opcode)
  422. {
  423. struct tcp_comm_ctx *ctx = calloc(1, sizeof(struct tcp_comm_ctx));
  424. if (!ctx) {
  425. return NULL;
  426. }
  427. unsigned int i;
  428. for (i = 0; i < n_cmds; i++) {
  429. assert(cmds[i]->nargs <= COMM_MAX_NARG);
  430. assert(cmds[i]->resp_nargs <= COMM_MAX_NARG);
  431. }
  432. ctx->cmds = cmds;
  433. ctx->n_cmds = n_cmds;
  434. ctx->sync_opcode = sync_opcode;
  435. return ctx;
  436. }
  437. void tcp_comm_delete(struct tcp_comm_ctx *ctx)
  438. {
  439. tcp_comm_server_close(ctx);
  440. free(ctx);
  441. }
  442. bool tcp_comm_server_done(struct tcp_comm_ctx *ctx)
  443. {
  444. return ctx->serv_done;
  445. }