Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

tcp_comm.c 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. /**
  2. * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
  3. *
  4. * Parts based on the Pico W tcp_server example:
  5. * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
  6. *
  7. * SPDX-License-Identifier: BSD-3-Clause
  8. */
  9. #include <stdlib.h>
  10. #include "pico/cyw43_arch.h"
  11. #include "lwip/pbuf.h"
  12. #include "lwip/tcp.h"
  13. #include "tcp_comm.h"
  /* Debug logging is routed straight to printf(). */
  #define DEBUG_printf printf

  /* Poll period in seconds; it is doubled where it is passed to tcp_poll(). */
  #define POLL_TIME_S 5

  /* Maximum number of 32-bit argument words; used to size the context buffer. */
  #define COMM_MAX_NARG 5
  /*
   * Phases of the per-connection state machine.
   * The enumerator order (and therefore the numeric values) is unchanged.
   */
  enum conn_state {
      CONN_STATE_WAIT_FOR_SYNC = 0, /* reading the initial sync word */
      CONN_STATE_READ_OPCODE,       /* reading a 32-bit opcode */
      CONN_STATE_READ_ARGS,         /* reading the command's argument words */
      CONN_STATE_READ_DATA,         /* reading the command's data body */
      CONN_STATE_HANDLE,            /* declared, but never assigned in the code visible here */
      CONN_STATE_WRITE_RESP,        /* a normal response has been queued */
      CONN_STATE_WRITE_ERROR,       /* an error response has been queued */
      CONN_STATE_CLOSED,            /* client connection closed or errored */
  };
  27. struct tcp_comm_ctx {
  28. struct tcp_pcb *serv_pcb;
  29. volatile bool serv_done;
  30. enum conn_state conn_state;
  31. struct tcp_pcb *client_pcb;
  32. uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + TCP_COMM_MAX_DATA_LEN];
  33. uint16_t rx_bytes_received;
  34. uint16_t rx_bytes_remaining;
  35. uint16_t tx_bytes_sent;
  36. uint16_t tx_bytes_remaining;
  37. uint32_t resp_data_len;
  38. const struct comm_command *cmd;
  39. const struct comm_command *const *cmds;
  40. unsigned int n_cmds;
  41. uint32_t sync_opcode;
  42. };
  /*
   * Accessors into the shared message buffer:
   *   BODY(buf, n) - first byte after the opcode word and n argument words
   *   OPCODE(buf)  - the leading 32-bit opcode / status word
   *   ARGS(buf)    - the argument words, i.e. the body position for n == 0
   */
  #define COMM_BUF_BODY(_buf, _nargs) ((uint8_t *)(_buf) + (sizeof(uint32_t) * ((_nargs) + 1)))
  #define COMM_BUF_OPCODE(_buf) ((uint32_t *)((uint8_t *)(_buf)))
  #define COMM_BUF_ARGS(_buf) ((uint32_t *)COMM_BUF_BODY((_buf), 0))
  46. static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
  47. {
  48. unsigned int i;
  49. for (i = 0; i < ctx->n_cmds; i++) {
  50. if (ctx->cmds[i]->opcode == opcode) {
  51. return ctx->cmds[i];
  52. }
  53. }
  54. return NULL;
  55. }
  56. static bool is_error(uint32_t status)
  57. {
  58. return status == TCP_COMM_RSP_ERR;
  59. }
  /*
   * Forward declarations for the state machine steps. Each *_begin enters a
   * state; each *_complete runs once that state's transfer has finished.
   */

  /* Receive side. */
  static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
  static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx);

  /* Transmit side. */
  static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx);
  71. static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx)
  72. {
  73. ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
  74. ctx->rx_bytes_received = 0;
  75. ctx->rx_bytes_remaining = sizeof(uint32_t);
  76. DEBUG_printf("sync_begin %d\n", ctx->rx_bytes_remaining);
  77. }
  78. static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx)
  79. {
  80. if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
  81. DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  82. return tcp_comm_error_begin(ctx);
  83. }
  84. return tcp_comm_opcode_complete(ctx);
  85. }
  86. static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx)
  87. {
  88. ctx->conn_state = CONN_STATE_READ_OPCODE;
  89. ctx->rx_bytes_received = 0;
  90. ctx->rx_bytes_remaining = sizeof(uint32_t);
  91. return 0;
  92. }
  93. static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx)
  94. {
  95. ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
  96. if (!ctx->cmd) {
  97. DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  98. return tcp_comm_error_begin(ctx);
  99. } else {
  100. DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  101. }
  102. return tcp_comm_args_begin(ctx);
  103. }
  104. static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx)
  105. {
  106. ctx->conn_state = CONN_STATE_READ_ARGS;
  107. ctx->rx_bytes_received = 0;
  108. ctx->rx_bytes_remaining = ctx->cmd->nargs * sizeof(uint32_t);
  109. if (ctx->cmd->nargs == 0) {
  110. return tcp_comm_args_complete(ctx);
  111. }
  112. return 0;
  113. }
  114. static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx)
  115. {
  116. const struct comm_command *cmd = ctx->cmd;
  117. uint32_t data_len = 0;
  118. if (cmd->size) {
  119. uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
  120. &data_len,
  121. &ctx->resp_data_len);
  122. if (is_error(status)) {
  123. return tcp_comm_error_begin(ctx);
  124. }
  125. }
  126. return tcp_comm_data_begin(ctx, data_len);
  127. }
  128. static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
  129. {
  130. const struct comm_command *cmd = ctx->cmd;
  131. ctx->conn_state = CONN_STATE_READ_DATA;
  132. ctx->rx_bytes_received = 0;
  133. ctx->rx_bytes_remaining = data_len;
  134. if (data_len == 0) {
  135. return tcp_comm_data_complete(ctx);
  136. }
  137. return 0;
  138. }
  139. static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx)
  140. {
  141. const struct comm_command *cmd = ctx->cmd;
  142. if (cmd->handle) {
  143. uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
  144. COMM_BUF_BODY(ctx->buf, cmd->nargs),
  145. COMM_BUF_ARGS(ctx->buf),
  146. COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
  147. if (is_error(status)) {
  148. return tcp_comm_error_begin(ctx);
  149. }
  150. *COMM_BUF_OPCODE(ctx->buf) = status;
  151. } else {
  152. // TODO: Should we just assert(desc->handle)?
  153. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_OK;
  154. }
  155. return tcp_comm_response_begin(ctx);
  156. }
  157. static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx)
  158. {
  159. ctx->conn_state = CONN_STATE_WRITE_RESP;
  160. ctx->tx_bytes_sent = 0;
  161. ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
  162. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  163. if (err != ERR_OK) {
  164. return -1;
  165. }
  166. return 0;
  167. }
  168. static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx)
  169. {
  170. ctx->conn_state = CONN_STATE_WRITE_ERROR;
  171. ctx->tx_bytes_sent = 0;
  172. ctx->tx_bytes_remaining = sizeof(uint32_t);
  173. *COMM_BUF_OPCODE(ctx->buf) = TCP_COMM_RSP_ERR;
  174. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  175. if (err != ERR_OK) {
  176. return -1;
  177. }
  178. return 0;
  179. }
  /* A response has been fully sent: go back to waiting for the next opcode. */
  static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
  {
      const int next = tcp_comm_opcode_begin(ctx);

      return next;
  }
  184. static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
  185. {
  186. switch (ctx->conn_state) {
  187. case CONN_STATE_WAIT_FOR_SYNC:
  188. return tcp_comm_sync_complete(ctx);
  189. case CONN_STATE_READ_OPCODE:
  190. return tcp_comm_opcode_complete(ctx);
  191. case CONN_STATE_READ_ARGS:
  192. return tcp_comm_args_complete(ctx);
  193. case CONN_STATE_READ_DATA:
  194. return tcp_comm_data_complete(ctx);
  195. default:
  196. return -1;
  197. }
  198. }
  199. static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
  200. {
  201. switch (ctx->conn_state) {
  202. case CONN_STATE_WRITE_RESP:
  203. return tcp_comm_response_complete(ctx);
  204. case CONN_STATE_WRITE_ERROR:
  205. return -1;
  206. default:
  207. return -1;
  208. }
  209. }
  210. static err_t tcp_comm_client_close(struct tcp_comm_ctx *ctx)
  211. {
  212. err_t err = ERR_OK;
  213. cyw43_arch_gpio_put (0, false);
  214. ctx->conn_state = CONN_STATE_CLOSED;
  215. if (!ctx->client_pcb) {
  216. return err;
  217. }
  218. tcp_arg(ctx->client_pcb, NULL);
  219. tcp_poll(ctx->client_pcb, NULL, 0);
  220. tcp_sent(ctx->client_pcb, NULL);
  221. tcp_recv(ctx->client_pcb, NULL);
  222. tcp_err(ctx->client_pcb, NULL);
  223. err = tcp_close(ctx->client_pcb);
  224. if (err != ERR_OK) {
  225. DEBUG_printf("close failed %d, calling abort\n", err);
  226. tcp_abort(ctx->client_pcb);
  227. err = ERR_ABRT;
  228. }
  229. ctx->client_pcb = NULL;
  230. return err;
  231. }
  232. err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx)
  233. {
  234. err_t err = ERR_OK;
  235. err = tcp_comm_client_close(ctx);
  236. if ((err != ERR_OK) && ctx->serv_pcb) {
  237. tcp_arg(ctx->serv_pcb, NULL);
  238. tcp_abort(ctx->serv_pcb);
  239. ctx->serv_pcb = NULL;
  240. return ERR_ABRT;
  241. }
  242. if (!ctx->serv_pcb) {
  243. return err;
  244. }
  245. tcp_arg(ctx->serv_pcb, NULL);
  246. err = tcp_close(ctx->serv_pcb);
  247. if (err != ERR_OK) {
  248. tcp_abort(ctx->serv_pcb);
  249. err = ERR_ABRT;
  250. }
  251. ctx->serv_pcb = NULL;
  252. return err;
  253. }
  254. static void tcp_comm_server_complete(void *arg, int status)
  255. {
  256. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  257. if (status == 0) {
  258. DEBUG_printf("server completed normally\n");
  259. } else {
  260. DEBUG_printf("server error %d\n", status);
  261. }
  262. tcp_comm_server_close(ctx);
  263. ctx->serv_done = true;
  264. }
  265. static err_t tcp_comm_client_complete(void *arg, int status)
  266. {
  267. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  268. if (status == 0) {
  269. DEBUG_printf("conn completed normally\n");
  270. } else {
  271. DEBUG_printf("conn error %d\n", status);
  272. }
  273. return tcp_comm_client_close(ctx);
  274. }
  275. static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
  276. {
  277. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  278. DEBUG_printf("tcp_comm_server_sent %u\n", len);
  279. cyw43_arch_lwip_check();
  280. if (len > ctx->tx_bytes_remaining) {
  281. DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
  282. return tcp_comm_client_complete(ctx, ERR_ARG);
  283. }
  284. ctx->tx_bytes_remaining -= len;
  285. ctx->tx_bytes_sent += len;
  286. if (ctx->tx_bytes_remaining == 0) {
  287. int res = tcp_comm_tx_complete(ctx);
  288. if (res) {
  289. return tcp_comm_client_complete(ctx, ERR_ARG);
  290. }
  291. }
  292. return ERR_OK;
  293. }
  294. static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
  295. {
  296. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  297. if (!p) {
  298. DEBUG_printf("no pbuf\n");
  299. return tcp_comm_client_complete(ctx, 0);
  300. }
  301. // this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
  302. // can use this method to cause an assertion in debug mode, if this method is called when
  303. // cyw43_arch_lwip_begin IS needed
  304. cyw43_arch_lwip_check();
  305. if (p->tot_len > 0) {
  306. DEBUG_printf("tcp_comm_server_recv %d err %d\n", p->tot_len, err);
  307. if (p->tot_len > ctx->rx_bytes_remaining) {
  308. DEBUG_printf("more data than expected: %d vs %d\n", p->tot_len, ctx->rx_bytes_remaining);
  309. // TODO: Invoking the error response state here feels
  310. // like a bit of a layering violation, but this is a
  311. // protocol error, rather than a failure in the stack
  312. // somewhere, so it's nice to try and report it rather
  313. // than just dropping the connection.
  314. if (tcp_comm_error_begin(ctx)) {
  315. return tcp_comm_client_complete(ctx, ERR_ARG);
  316. }
  317. return ERR_OK;
  318. }
  319. // Receive the buffer
  320. if (pbuf_copy_partial(p, ctx->buf + ctx->rx_bytes_received, p->tot_len, 0) != p->tot_len) {
  321. DEBUG_printf("wrong copy len\n");
  322. return tcp_comm_client_complete(ctx, ERR_ARG);
  323. }
  324. ctx->rx_bytes_received += p->tot_len;
  325. ctx->rx_bytes_remaining -= p->tot_len;
  326. tcp_recved(tpcb, p->tot_len);
  327. if (ctx->rx_bytes_remaining == 0) {
  328. int res = tcp_comm_rx_complete(ctx);
  329. if (res) {
  330. return tcp_comm_client_complete(ctx, ERR_ARG);
  331. }
  332. }
  333. }
  334. pbuf_free(p);
  335. return ERR_OK;
  336. }
  337. static err_t tcp_comm_client_poll(void *arg, struct tcp_pcb *tpcb)
  338. {
  339. DEBUG_printf("tcp_comm_server_poll_fn\n");
  340. return ERR_OK;
  341. }
  342. static void tcp_comm_client_err(void *arg, err_t err)
  343. {
  344. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  345. DEBUG_printf("tcp_comm_err %d\n", err);
  346. ctx->client_pcb = NULL;
  347. ctx->conn_state = CONN_STATE_CLOSED;
  348. ctx->rx_bytes_remaining = 0;
  349. cyw43_arch_gpio_put (0, false);
  350. }
  351. static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
  352. {
  353. ctx->client_pcb = pcb;
  354. tcp_arg(pcb, ctx);
  355. cyw43_arch_gpio_put (0, true);
  356. tcp_comm_sync_begin(ctx);
  357. tcp_sent(pcb, tcp_comm_client_sent);
  358. tcp_recv(pcb, tcp_comm_client_recv);
  359. tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
  360. tcp_err(pcb, tcp_comm_client_err);
  361. }
  362. static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
  363. {
  364. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  365. if (err != ERR_OK || client_pcb == NULL) {
  366. DEBUG_printf("Failure in accept\n");
  367. tcp_comm_server_complete(ctx, err);
  368. return ERR_VAL;
  369. }
  370. DEBUG_printf("Connection opened\n");
  371. if (ctx->client_pcb) {
  372. DEBUG_printf("Already have a connection\n");
  373. tcp_abort(client_pcb);
  374. return ERR_ABRT;
  375. }
  376. tcp_comm_client_init(ctx, client_pcb);
  377. return ERR_OK;
  378. }
  379. err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port)
  380. {
  381. DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), port);
  382. ctx->serv_done = false;
  383. struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
  384. if (!pcb) {
  385. DEBUG_printf("failed to create pcb\n");
  386. return ERR_MEM;
  387. }
  388. err_t err = tcp_bind(pcb, NULL, port);
  389. if (err) {
  390. DEBUG_printf("failed to bind to port %d\n", port);
  391. tcp_abort(pcb);
  392. return err;
  393. }
  394. ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
  395. if (!ctx->serv_pcb) {
  396. DEBUG_printf("failed to listen: %d\n", err);
  397. return err;
  398. }
  399. tcp_arg(ctx->serv_pcb, ctx);
  400. tcp_accept(ctx->serv_pcb, tcp_comm_server_accept);
  401. return ERR_OK;
  402. }
  403. struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
  404. unsigned int n_cmds, uint32_t sync_opcode)
  405. {
  406. struct tcp_comm_ctx *ctx = calloc(1, sizeof(struct tcp_comm_ctx));
  407. if (!ctx) {
  408. return NULL;
  409. }
  410. unsigned int i;
  411. for (i = 0; i < n_cmds; i++) {
  412. assert(cmds[i]->nargs <= MAX_NARG);
  413. assert(cmds[i]->resp_nargs <= MAX_NARG);
  414. }
  415. ctx->cmds = cmds;
  416. ctx->n_cmds = n_cmds;
  417. ctx->sync_opcode = sync_opcode;
  418. return ctx;
  419. }
  /*
   * Close any open connections and free the context.
   *
   * A NULL ctx is accepted and ignored, like free(NULL). tcp_comm_new() can
   * return NULL, and the original dereferenced it here.
   */
  void tcp_comm_delete(struct tcp_comm_ctx *ctx)
  {
      if (!ctx) {
          return;
      }
      tcp_comm_server_close(ctx);
      free(ctx);
  }
  425. bool tcp_comm_server_done(struct tcp_comm_ctx *ctx)
  426. {
  427. return ctx->serv_done;
  428. }