You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

main.c 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. /**
  2. * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
  3. *
  4. * Based on the Pico W tcp_server example:
  5. * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
  6. *
  7. * SPDX-License-Identifier: BSD-3-Clause
  8. */
  9. #include <string.h>
  10. #include <stdlib.h>
  11. #include "pico/stdlib.h"
  12. #include "pico/cyw43_arch.h"
  13. #include "lwip/pbuf.h"
  14. #include "lwip/tcp.h"
/* Wi-Fi credentials passed to cyw43_arch_wifi_connect_timeout_ms() in main();
 * the definitions live in another translation unit. */
extern const char *wifi_ssid;
extern const char *wifi_pass;
/* TCP port the listening socket is bound to. */
#define TCP_PORT 4242
/* Debug logging is routed straight to printf. */
#define DEBUG_printf printf
/* Base poll period; tcp_conn_init() passes POLL_TIME_S * 2 to tcp_poll(). */
#define POLL_TIME_S 5
/* NOTE(review): not referenced anywhere in the visible source. */
#define MAX_LEN 2048
/* Limits used to size the per-connection buffer in struct tcp_comm_ctx:
 * maximum number of 32-bit argument words and maximum payload bytes. */
#define COMM_MAX_NARG 5
#define COMM_MAX_DATA_LEN 1024
/* Protocol words: four ASCII characters packed into 32 bits, with the first
 * character in the least significant byte. */
#define COMM_RSP_OK (('O' << 0) | ('K' << 8) | ('O' << 16) | ('K' << 24))
#define COMM_RSP_ERR (('E' << 0) | ('R' << 8) | ('R' << 16) | ('!' << 24))
/* Opcode of the sync command and the status word its handler returns. */
#define CMD_SYNC (('S' << 0) | ('Y' << 8) | ('N' << 16) | ('C' << 24))
#define RSP_SYNC (('W' << 0) | ('O' << 8) | ('T' << 16) | ('A' << 24))
/* Descriptor for a single protocol command. */
struct comm_command {
    /* Opcode word this command is looked up by (see find_command_desc()). */
    uint32_t opcode;
    /* Number of 32-bit argument words received after the opcode. */
    uint32_t nargs;
    /* Number of 32-bit argument words sent after the status word in the response. */
    uint32_t resp_nargs;
    /* Optional (may be NULL). Called once the arguments have been received to
     * report the request payload length and the response payload length.
     * Returning COMM_RSP_ERR rejects the command. */
    uint32_t (*size)(uint32_t *args_in, uint32_t *data_len_out, uint32_t *resp_data_len_out);
    /* Optional (may be NULL). Called once the payload has been received. The
     * return value becomes the response status word; COMM_RSP_ERR turns the
     * response into an error. */
    uint32_t (*handle)(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_args_out, uint8_t *resp_data_out);
};
/* States of the per-connection protocol state machine. */
enum conn_state {
    /* Waiting for the first 4-byte word, which must equal the sync opcode. */
    CONN_STATE_WAIT_FOR_SYNC,
    /* Waiting for a 4-byte command opcode. */
    CONN_STATE_READ_OPCODE,
    /* Waiting for the command's argument words. */
    CONN_STATE_READ_ARGS,
    /* Waiting for the command's payload bytes. */
    CONN_STATE_READ_DATA,
    /* NOTE(review): never assigned in the visible source. */
    CONN_STATE_HANDLE,
    /* A response has been queued; waiting for it to be acknowledged. */
    CONN_STATE_WRITE_RESP,
    /* An error word has been queued; the connection is closed once it is sent. */
    CONN_STATE_WRITE_ERROR,
    /* No active connection (set by tcp_conn_close() and tcp_conn_err()). */
    CONN_STATE_CLOSED,
};
/* State for the listening socket and the single client connection it serves. */
struct tcp_comm_ctx {
    /* Listening PCB; NULL when the server is not listening. */
    struct tcp_pcb *serv_pcb;
    /* Set by tcp_server_complete(); polled by the main loop. */
    volatile bool serv_done;

    /* Current protocol state of the client connection. */
    enum conn_state conn_state;
    /* PCB of the accepted client; NULL when there is no connection. */
    struct tcp_pcb *conn_pcb;

    /* Shared request/response buffer: one opcode/status word, up to
     * COMM_MAX_NARG argument words, then up to COMM_MAX_DATA_LEN payload bytes.
     * NOTE(review): accessed through uint32_t pointers by the COMM_BUF_*
     * macros, so it must be 4-byte aligned; nothing here enforces that
     * explicitly - confirm. */
    uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + COMM_MAX_DATA_LEN];

    /* tcp_conn_recv() copies incoming bytes to buf + rx_bytes_received. */
    uint16_t rx_bytes_received;
    /* Bytes still required before the current receive state completes. */
    uint16_t rx_bytes_remaining;

    /* Bytes of the queued response acknowledged so far / still outstanding. */
    uint16_t tx_bytes_sent;
    uint16_t tx_bytes_remaining;

    /* Response payload length reported by the command's size() callback. */
    uint32_t resp_data_len;

    /* Command currently being processed (result of find_command_desc()). */
    const struct comm_command *cmd;
    /* Table of supported commands and its length. */
    const struct comm_command *const *cmds;
    unsigned int n_cmds;
    /* Opcode that must be received first on every new connection. */
    uint32_t sync_opcode;
};
  60. #define COMM_BUF_OPCODE(_buf) ((uint32_t *)((uint8_t *)(_buf)))
  61. #define COMM_BUF_ARGS(_buf) ((uint32_t *)((uint8_t *)(_buf) + sizeof(uint32_t)))
  62. #define COMM_BUF_BODY(_buf, _nargs) ((uint8_t *)(_buf) + (sizeof(uint32_t) * ((_nargs) + 1)))
  63. static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
  64. {
  65. unsigned int i;
  66. for (i = 0; i < ctx->n_cmds; i++) {
  67. if (ctx->cmds[i]->opcode == opcode) {
  68. return ctx->cmds[i];
  69. }
  70. }
  71. return NULL;
  72. }
  73. static bool is_error(uint32_t status)
  74. {
  75. return status == COMM_RSP_ERR;
  76. }
  77. static int tcp_conn_sync_begin(struct tcp_comm_ctx *ctx);
  78. static int tcp_conn_sync_complete(struct tcp_comm_ctx *ctx);
  79. static int tcp_conn_opcode_begin(struct tcp_comm_ctx *ctx);
  80. static int tcp_conn_opcode_complete(struct tcp_comm_ctx *ctx);
  81. static int tcp_conn_args_begin(struct tcp_comm_ctx *ctx);
  82. static int tcp_conn_args_complete(struct tcp_comm_ctx *ctx);
  83. static int tcp_conn_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
  84. static int tcp_conn_data_complete(struct tcp_comm_ctx *ctx);
  85. static int tcp_conn_response_begin(struct tcp_comm_ctx *ctx);
  86. static int tcp_conn_response_complete(struct tcp_comm_ctx *ctx);
  87. static int tcp_conn_error_begin(struct tcp_comm_ctx *ctx);
  88. static int tcp_conn_sync_begin(struct tcp_comm_ctx *ctx)
  89. {
  90. ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
  91. ctx->rx_bytes_received = 0;
  92. ctx->rx_bytes_remaining = sizeof(uint32_t);
  93. DEBUG_printf("sync_begin %d\n", ctx->rx_bytes_remaining);
  94. }
  95. static int tcp_conn_sync_complete(struct tcp_comm_ctx *ctx)
  96. {
  97. if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
  98. DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  99. return tcp_conn_error_begin(ctx);
  100. }
  101. return tcp_conn_opcode_complete(ctx);
  102. }
  103. static int tcp_conn_opcode_begin(struct tcp_comm_ctx *ctx)
  104. {
  105. ctx->conn_state = CONN_STATE_READ_OPCODE;
  106. ctx->rx_bytes_received = 0;
  107. ctx->rx_bytes_remaining = sizeof(uint32_t);
  108. return 0;
  109. }
  110. static int tcp_conn_opcode_complete(struct tcp_comm_ctx *ctx)
  111. {
  112. ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
  113. if (!ctx->cmd) {
  114. DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  115. return tcp_conn_error_begin(ctx);
  116. } else {
  117. DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  118. }
  119. return tcp_conn_args_begin(ctx);
  120. }
  121. static int tcp_conn_args_begin(struct tcp_comm_ctx *ctx)
  122. {
  123. ctx->conn_state = CONN_STATE_READ_ARGS;
  124. ctx->rx_bytes_received = 0;
  125. ctx->rx_bytes_remaining = ctx->cmd->nargs * sizeof(uint32_t);
  126. if (ctx->cmd->nargs == 0) {
  127. return tcp_conn_args_complete(ctx);
  128. }
  129. return 0;
  130. }
  131. static int tcp_conn_args_complete(struct tcp_comm_ctx *ctx)
  132. {
  133. const struct comm_command *cmd = ctx->cmd;
  134. uint32_t data_len = 0;
  135. if (cmd->size) {
  136. uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
  137. &data_len,
  138. &ctx->resp_data_len);
  139. if (is_error(status)) {
  140. return tcp_conn_error_begin(ctx);
  141. }
  142. }
  143. return tcp_conn_data_begin(ctx, data_len);
  144. }
  145. static int tcp_conn_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
  146. {
  147. const struct comm_command *cmd = ctx->cmd;
  148. ctx->conn_state = CONN_STATE_READ_DATA;
  149. ctx->rx_bytes_received = 0;
  150. ctx->rx_bytes_remaining = data_len;
  151. if (data_len == 0) {
  152. return tcp_conn_data_complete(ctx);
  153. }
  154. return 0;
  155. }
  156. static int tcp_conn_data_complete(struct tcp_comm_ctx *ctx)
  157. {
  158. const struct comm_command *cmd = ctx->cmd;
  159. if (cmd->handle) {
  160. uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
  161. COMM_BUF_BODY(ctx->buf, cmd->nargs),
  162. COMM_BUF_ARGS(ctx->buf),
  163. COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
  164. if (is_error(status)) {
  165. return tcp_conn_error_begin(ctx);
  166. }
  167. *COMM_BUF_OPCODE(ctx->buf) = status;
  168. } else {
  169. // TODO: Should we just assert(desc->handle)?
  170. *COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_OK;
  171. }
  172. return tcp_conn_response_begin(ctx);
  173. }
  174. static int tcp_conn_response_begin(struct tcp_comm_ctx *ctx)
  175. {
  176. ctx->conn_state = CONN_STATE_WRITE_RESP;
  177. ctx->tx_bytes_sent = 0;
  178. ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
  179. err_t err = tcp_write(ctx->conn_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  180. if (err != ERR_OK) {
  181. return -1;
  182. }
  183. return 0;
  184. }
  185. static int tcp_conn_error_begin(struct tcp_comm_ctx *ctx)
  186. {
  187. ctx->conn_state = CONN_STATE_WRITE_ERROR;
  188. ctx->tx_bytes_sent = 0;
  189. ctx->tx_bytes_remaining = sizeof(uint32_t);
  190. *COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_ERR;
  191. err_t err = tcp_write(ctx->conn_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  192. if (err != ERR_OK) {
  193. return -1;
  194. }
  195. return 0;
  196. }
  197. static int tcp_conn_response_complete(struct tcp_comm_ctx *ctx)
  198. {
  199. return tcp_conn_opcode_begin(ctx);
  200. }
  201. static int tcp_conn_rx_complete(struct tcp_comm_ctx *ctx)
  202. {
  203. switch (ctx->conn_state) {
  204. case CONN_STATE_WAIT_FOR_SYNC:
  205. return tcp_conn_sync_complete(ctx);
  206. case CONN_STATE_READ_OPCODE:
  207. return tcp_conn_opcode_complete(ctx);
  208. case CONN_STATE_READ_ARGS:
  209. return tcp_conn_args_complete(ctx);
  210. case CONN_STATE_READ_DATA:
  211. return tcp_conn_data_complete(ctx);
  212. default:
  213. return -1;
  214. }
  215. }
  216. static int tcp_conn_tx_complete(struct tcp_comm_ctx *ctx)
  217. {
  218. switch (ctx->conn_state) {
  219. case CONN_STATE_WRITE_RESP:
  220. return tcp_conn_response_complete(ctx);
  221. case CONN_STATE_WRITE_ERROR:
  222. return -1;
  223. default:
  224. return -1;
  225. }
  226. }
  227. static err_t tcp_conn_close(struct tcp_comm_ctx *ctx)
  228. {
  229. err_t err = ERR_OK;
  230. cyw43_arch_gpio_put (0, false);
  231. ctx->conn_state = CONN_STATE_CLOSED;
  232. if (!ctx->conn_pcb) {
  233. return err;
  234. }
  235. tcp_arg(ctx->conn_pcb, NULL);
  236. tcp_poll(ctx->conn_pcb, NULL, 0);
  237. tcp_sent(ctx->conn_pcb, NULL);
  238. tcp_recv(ctx->conn_pcb, NULL);
  239. tcp_err(ctx->conn_pcb, NULL);
  240. err = tcp_close(ctx->conn_pcb);
  241. if (err != ERR_OK) {
  242. DEBUG_printf("close failed %d, calling abort\n", err);
  243. tcp_abort(ctx->conn_pcb);
  244. err = ERR_ABRT;
  245. }
  246. ctx->conn_pcb = NULL;
  247. return err;
  248. }
  249. static err_t tcp_server_close(void *arg)
  250. {
  251. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  252. err_t err = ERR_OK;
  253. err = tcp_conn_close(ctx);
  254. if ((err != ERR_OK) && ctx->serv_pcb) {
  255. tcp_arg(ctx->serv_pcb, NULL);
  256. tcp_abort(ctx->serv_pcb);
  257. ctx->serv_pcb = NULL;
  258. return ERR_ABRT;
  259. }
  260. if (!ctx->serv_pcb) {
  261. return err;
  262. }
  263. tcp_arg(ctx->serv_pcb, NULL);
  264. err = tcp_close(ctx->serv_pcb);
  265. if (err != ERR_OK) {
  266. tcp_abort(ctx->serv_pcb);
  267. err = ERR_ABRT;
  268. }
  269. ctx->serv_pcb = NULL;
  270. return err;
  271. }
  272. static void tcp_server_complete(void *arg, int status)
  273. {
  274. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  275. if (status == 0) {
  276. DEBUG_printf("server completed normally\n");
  277. } else {
  278. DEBUG_printf("server error %d\n", status);
  279. }
  280. tcp_server_close(ctx);
  281. ctx->serv_done = true;
  282. }
  283. static err_t tcp_conn_complete(void *arg, int status)
  284. {
  285. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  286. if (status == 0) {
  287. DEBUG_printf("conn completed normally\n");
  288. } else {
  289. DEBUG_printf("conn error %d\n", status);
  290. }
  291. return tcp_conn_close(ctx);
  292. }
  293. static err_t tcp_conn_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
  294. {
  295. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  296. DEBUG_printf("tcp_server_sent %u\n", len);
  297. cyw43_arch_lwip_check();
  298. if (len > ctx->tx_bytes_remaining) {
  299. DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
  300. return tcp_conn_complete(ctx, ERR_ARG);
  301. }
  302. ctx->tx_bytes_remaining -= len;
  303. ctx->tx_bytes_sent += len;
  304. if (ctx->tx_bytes_remaining == 0) {
  305. int res = tcp_conn_tx_complete(ctx);
  306. if (res) {
  307. return tcp_conn_complete(ctx, ERR_ARG);
  308. }
  309. }
  310. return ERR_OK;
  311. }
  312. static err_t tcp_conn_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
  313. {
  314. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  315. if (!p) {
  316. DEBUG_printf("no pbuf\n");
  317. return tcp_conn_complete(ctx, 0);
  318. }
  319. // this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
  320. // can use this method to cause an assertion in debug mode, if this method is called when
  321. // cyw43_arch_lwip_begin IS needed
  322. cyw43_arch_lwip_check();
  323. if (p->tot_len > 0) {
  324. DEBUG_printf("tcp_server_recv %d err %d\n", p->tot_len, err);
  325. size_t to_copy = p->tot_len > ctx->rx_bytes_remaining ? ctx->rx_bytes_remaining : p->tot_len;
  326. // Receive the buffer
  327. if (pbuf_copy_partial(p, ctx->buf + ctx->rx_bytes_received, to_copy, 0) != to_copy) {
  328. DEBUG_printf("wrong copy len\n");
  329. return tcp_conn_complete(ctx, ERR_ARG);
  330. }
  331. ctx->rx_bytes_received += to_copy;
  332. ctx->rx_bytes_remaining -= to_copy;
  333. tcp_recved(tpcb, p->tot_len);
  334. if (ctx->rx_bytes_remaining == 0) {
  335. int res = tcp_conn_rx_complete(ctx);
  336. if (res) {
  337. return tcp_conn_complete(ctx, ERR_ARG);
  338. }
  339. }
  340. }
  341. pbuf_free(p);
  342. return ERR_OK;
  343. }
  344. static err_t tcp_conn_poll(void *arg, struct tcp_pcb *tpcb)
  345. {
  346. DEBUG_printf("tcp_server_poll_fn\n");
  347. return ERR_OK;
  348. }
  349. static void tcp_conn_err(void *arg, err_t err)
  350. {
  351. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  352. DEBUG_printf("tcp_conn_err %d\n", err);
  353. ctx->conn_pcb = NULL;
  354. ctx->conn_state = CONN_STATE_CLOSED;
  355. ctx->rx_bytes_remaining = 0;
  356. cyw43_arch_gpio_put (0, false);
  357. }
  358. static void tcp_conn_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
  359. {
  360. ctx->conn_pcb = pcb;
  361. tcp_arg(pcb, ctx);
  362. cyw43_arch_gpio_put (0, true);
  363. tcp_conn_sync_begin(ctx);
  364. tcp_sent(pcb, tcp_conn_sent);
  365. tcp_recv(pcb, tcp_conn_recv);
  366. tcp_poll(pcb, tcp_conn_poll, POLL_TIME_S * 2);
  367. tcp_err(pcb, tcp_conn_err);
  368. }
  369. static err_t tcp_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
  370. {
  371. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  372. if (err != ERR_OK || client_pcb == NULL) {
  373. DEBUG_printf("Failure in accept\n");
  374. tcp_server_complete(ctx, err);
  375. return ERR_VAL;
  376. }
  377. DEBUG_printf("Connection opened\n");
  378. if (ctx->conn_pcb) {
  379. DEBUG_printf("Already have a connection\n");
  380. tcp_abort(client_pcb);
  381. return ERR_ABRT;
  382. }
  383. tcp_conn_init(ctx, client_pcb);
  384. return ERR_OK;
  385. }
  386. static err_t tcp_server_listen(struct tcp_comm_ctx *ctx)
  387. {
  388. DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), TCP_PORT);
  389. ctx->serv_done = false;
  390. struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
  391. if (!pcb) {
  392. DEBUG_printf("failed to create pcb\n");
  393. return ERR_MEM;
  394. }
  395. err_t err = tcp_bind(pcb, NULL, TCP_PORT);
  396. if (err) {
  397. DEBUG_printf("failed to bind to port %d\n", TCP_PORT);
  398. tcp_abort(pcb);
  399. return err;
  400. }
  401. ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
  402. if (!ctx->serv_pcb) {
  403. DEBUG_printf("failed to listen: %d\n", err);
  404. return err;
  405. }
  406. tcp_arg(ctx->serv_pcb, ctx);
  407. tcp_accept(ctx->serv_pcb, tcp_server_accept);
  408. return ERR_OK;
  409. }
  410. static uint32_t handle_sync(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_args_out, uint8_t *resp_data_out)
  411. {
  412. return RSP_SYNC;
  413. }
  414. const struct comm_command util_sync_cmd = {
  415. .opcode = CMD_SYNC,
  416. .nargs = 0,
  417. .resp_nargs = 0,
  418. .size = NULL,
  419. .handle = &handle_sync,
  420. };
  421. static void tcp_comm_init(struct tcp_comm_ctx *ctx, const struct comm_command *const *cmds,
  422. unsigned int n_cmds, uint32_t sync_opcode)
  423. {
  424. unsigned int i;
  425. for (i = 0; i < n_cmds; i++) {
  426. assert(cmds[i]->nargs <= MAX_NARG);
  427. assert(cmds[i]->resp_nargs <= MAX_NARG);
  428. }
  429. memset(ctx, 0, sizeof(*ctx));
  430. ctx->cmds = cmds;
  431. ctx->n_cmds = n_cmds;
  432. ctx->sync_opcode = sync_opcode;
  433. }
  434. int main()
  435. {
  436. stdio_init_all();
  437. sleep_ms(1000);
  438. if (cyw43_arch_init()) {
  439. printf("failed to initialise\n");
  440. return 1;
  441. }
  442. cyw43_arch_enable_sta_mode();
  443. printf("Connecting to WiFi...\n");
  444. if (cyw43_arch_wifi_connect_timeout_ms(wifi_ssid, wifi_pass, CYW43_AUTH_WPA2_AES_PSK, 30000)) {
  445. printf("failed to connect.\n");
  446. return 1;
  447. } else {
  448. printf("Connected.\n");
  449. }
  450. struct tcp_comm_ctx tcp;
  451. const struct comm_command *cmds[] = {
  452. &util_sync_cmd,
  453. };
  454. tcp_comm_init(&tcp, cmds, 1, CMD_SYNC);
  455. for ( ; ; ) {
  456. err_t err = tcp_server_listen(&tcp);
  457. if (err != ERR_OK) {
  458. printf("Failed to start server: %d\n", err);
  459. sleep_ms(1000);
  460. continue;
  461. }
  462. while (!tcp.serv_done) {
  463. cyw43_arch_poll();
  464. sleep_ms(10);
  465. }
  466. }
  467. cyw43_arch_deinit();
  468. return 0;
  469. }