Browse Source

Factor out the TCP code

Now isolated for easy re-use. Should split it into a separate library.
Brian Starkey 2 years ago
parent
commit
e0da727e96
4 changed files with 560 additions and 505 deletions
  1. 1
    0
      CMakeLists.txt
  2. 4
    505
      main.c
  3. 525
    0
      tcp_comm.c
  4. 30
    0
      tcp_comm.h

+ 1
- 0
CMakeLists.txt View File

@@ -10,6 +10,7 @@ pico_sdk_init()
10 10
 add_executable(picowota
11 11
 	main.c
12 12
 	creds.c
13
+	tcp_comm.c
13 14
 )
14 15
 
15 16
 pico_enable_stdio_usb(picowota 1)

+ 4
- 505
main.c View File

@@ -13,503 +13,18 @@
13 13
 #include "pico/stdlib.h"
14 14
 #include "pico/cyw43_arch.h"
15 15
 
16
-#include "lwip/pbuf.h"
17
-#include "lwip/tcp.h"
16
+#include "tcp_comm.h"
18 17
 
19 18
 extern const char *wifi_ssid;
20 19
 extern const char *wifi_pass;
21 20
 
22 21
 #define TCP_PORT 4242
23
-#define DEBUG_printf printf
24
-#define POLL_TIME_S 5
25 22
 
26 23
 #define MAX_LEN 2048
27 24
 
28
-#define COMM_MAX_NARG     5
29
-#define COMM_MAX_DATA_LEN 1024
30
-
31
-#define COMM_RSP_OK       (('O' << 0) | ('K' << 8) | ('O' << 16) | ('K' << 24))
32
-#define COMM_RSP_ERR      (('E' << 0) | ('R' << 8) | ('R' << 16) | ('!' << 24))
33 25
 #define CMD_SYNC          (('S' << 0) | ('Y' << 8) | ('N' << 16) | ('C' << 24))
34 26
 #define RSP_SYNC          (('W' << 0) | ('O' << 8) | ('T' << 16) | ('A' << 24))
35 27
 
36
-struct comm_command {
37
-	uint32_t opcode;
38
-	uint32_t nargs;
39
-	uint32_t resp_nargs;
40
-	uint32_t (*size)(uint32_t *args_in, uint32_t *data_len_out, uint32_t *resp_data_len_out);
41
-	uint32_t (*handle)(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_args_out, uint8_t *resp_data_out);
42
-};
43
-
44
-enum conn_state {
45
-	CONN_STATE_WAIT_FOR_SYNC,
46
-	CONN_STATE_READ_OPCODE,
47
-	CONN_STATE_READ_ARGS,
48
-	CONN_STATE_READ_DATA,
49
-	CONN_STATE_HANDLE,
50
-	CONN_STATE_WRITE_RESP,
51
-	CONN_STATE_WRITE_ERROR,
52
-	CONN_STATE_CLOSED,
53
-};
54
-
55
-struct tcp_comm_ctx {
56
-	struct tcp_pcb *serv_pcb;
57
-	volatile bool serv_done;
58
-	enum conn_state conn_state;
59
-
60
-	struct tcp_pcb *conn_pcb;
61
-	uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + COMM_MAX_DATA_LEN];
62
-	uint16_t rx_bytes_received;
63
-	uint16_t rx_bytes_remaining;
64
-
65
-	uint16_t tx_bytes_sent;
66
-	uint16_t tx_bytes_remaining;
67
-
68
-	uint32_t resp_data_len;
69
-
70
-	const struct comm_command *cmd;
71
-	const struct comm_command *const *cmds;
72
-	unsigned int n_cmds;
73
-	uint32_t sync_opcode;
74
-};
75
-
76
-#define COMM_BUF_OPCODE(_buf)       ((uint32_t *)((uint8_t *)(_buf)))
77
-#define COMM_BUF_ARGS(_buf)         ((uint32_t *)((uint8_t *)(_buf) + sizeof(uint32_t)))
78
-#define COMM_BUF_BODY(_buf, _nargs) ((uint8_t *)(_buf) + (sizeof(uint32_t) * ((_nargs) + 1)))
79
-
80
-static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
81
-{
82
-	unsigned int i;
83
-
84
-	for (i = 0; i < ctx->n_cmds; i++) {
85
-		if (ctx->cmds[i]->opcode == opcode) {
86
-			return ctx->cmds[i];
87
-		}
88
-	}
89
-
90
-	return NULL;
91
-}
92
-
93
-static bool is_error(uint32_t status)
94
-{
95
-	return status == COMM_RSP_ERR;
96
-}
97
-
98
-static int tcp_conn_sync_begin(struct tcp_comm_ctx *ctx);
99
-static int tcp_conn_sync_complete(struct tcp_comm_ctx *ctx);
100
-static int tcp_conn_opcode_begin(struct tcp_comm_ctx *ctx);
101
-static int tcp_conn_opcode_complete(struct tcp_comm_ctx *ctx);
102
-static int tcp_conn_args_begin(struct tcp_comm_ctx *ctx);
103
-static int tcp_conn_args_complete(struct tcp_comm_ctx *ctx);
104
-static int tcp_conn_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
105
-static int tcp_conn_data_complete(struct tcp_comm_ctx *ctx);
106
-static int tcp_conn_response_begin(struct tcp_comm_ctx *ctx);
107
-static int tcp_conn_response_complete(struct tcp_comm_ctx *ctx);
108
-static int tcp_conn_error_begin(struct tcp_comm_ctx *ctx);
109
-
110
-static int tcp_conn_sync_begin(struct tcp_comm_ctx *ctx)
111
-{
112
-	ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
113
-	ctx->rx_bytes_received = 0;
114
-	ctx->rx_bytes_remaining = sizeof(uint32_t);
115
-
116
-	DEBUG_printf("sync_begin %d\n", ctx->rx_bytes_remaining);
117
-}
118
-
119
-static int tcp_conn_sync_complete(struct tcp_comm_ctx *ctx)
120
-{
121
-	if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
122
-		DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
123
-		return tcp_conn_error_begin(ctx);
124
-	}
125
-
126
-	return tcp_conn_opcode_complete(ctx);
127
-}
128
-
129
-static int tcp_conn_opcode_begin(struct tcp_comm_ctx *ctx)
130
-{
131
-	ctx->conn_state = CONN_STATE_READ_OPCODE;
132
-	ctx->rx_bytes_received = 0;
133
-	ctx->rx_bytes_remaining = sizeof(uint32_t);
134
-
135
-	return 0;
136
-}
137
-
138
-static int tcp_conn_opcode_complete(struct tcp_comm_ctx *ctx)
139
-{
140
-	ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
141
-	if (!ctx->cmd) {
142
-		DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
143
-		return tcp_conn_error_begin(ctx);
144
-	} else {
145
-		DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
146
-	}
147
-
148
-	return tcp_conn_args_begin(ctx);
149
-}
150
-
151
-static int tcp_conn_args_begin(struct tcp_comm_ctx *ctx)
152
-{
153
-	ctx->conn_state = CONN_STATE_READ_ARGS;
154
-	ctx->rx_bytes_received = 0;
155
-	ctx->rx_bytes_remaining = ctx->cmd->nargs * sizeof(uint32_t);
156
-
157
-	if (ctx->cmd->nargs == 0) {
158
-		return tcp_conn_args_complete(ctx);
159
-	}
160
-
161
-	return 0;
162
-}
163
-
164
-static int tcp_conn_args_complete(struct tcp_comm_ctx *ctx)
165
-{
166
-	const struct comm_command *cmd = ctx->cmd;
167
-
168
-	uint32_t data_len = 0;
169
-
170
-	if (cmd->size) {
171
-		uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
172
-					    &data_len,
173
-					    &ctx->resp_data_len);
174
-		if (is_error(status)) {
175
-			return tcp_conn_error_begin(ctx);
176
-		}
177
-	}
178
-
179
-	return tcp_conn_data_begin(ctx, data_len);
180
-}
181
-
182
-static int tcp_conn_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
183
-{
184
-	const struct comm_command *cmd = ctx->cmd;
185
-
186
-	ctx->conn_state = CONN_STATE_READ_DATA;
187
-	ctx->rx_bytes_received = 0;
188
-	ctx->rx_bytes_remaining = data_len;
189
-
190
-	if (data_len == 0) {
191
-		return tcp_conn_data_complete(ctx);
192
-	}
193
-
194
-
195
-	return 0;
196
-}
197
-
198
-static int tcp_conn_data_complete(struct tcp_comm_ctx *ctx)
199
-{
200
-	const struct comm_command *cmd = ctx->cmd;
201
-
202
-	if (cmd->handle) {
203
-		uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
204
-					      COMM_BUF_BODY(ctx->buf, cmd->nargs),
205
-					      COMM_BUF_ARGS(ctx->buf),
206
-					      COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
207
-		if (is_error(status)) {
208
-			return tcp_conn_error_begin(ctx);
209
-		}
210
-
211
-		*COMM_BUF_OPCODE(ctx->buf) = status;
212
-	} else {
213
-		// TODO: Should we just assert(desc->handle)?
214
-		*COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_OK;
215
-	}
216
-
217
-	return tcp_conn_response_begin(ctx);
218
-}
219
-
220
-static int tcp_conn_response_begin(struct tcp_comm_ctx *ctx)
221
-{
222
-	ctx->conn_state = CONN_STATE_WRITE_RESP;
223
-	ctx->tx_bytes_sent = 0;
224
-	ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
225
-
226
-	err_t err = tcp_write(ctx->conn_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
227
-	if (err != ERR_OK) {
228
-		return -1;
229
-	}
230
-
231
-	return 0;
232
-}
233
-
234
-static int tcp_conn_error_begin(struct tcp_comm_ctx *ctx)
235
-{
236
-	ctx->conn_state = CONN_STATE_WRITE_ERROR;
237
-	ctx->tx_bytes_sent = 0;
238
-	ctx->tx_bytes_remaining = sizeof(uint32_t);
239
-
240
-	*COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_ERR;
241
-
242
-	err_t err = tcp_write(ctx->conn_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
243
-	if (err != ERR_OK) {
244
-		return -1;
245
-	}
246
-
247
-	return 0;
248
-}
249
-
250
-
251
-static int tcp_conn_response_complete(struct tcp_comm_ctx *ctx)
252
-{
253
-	return tcp_conn_opcode_begin(ctx);
254
-}
255
-
256
-static int tcp_conn_rx_complete(struct tcp_comm_ctx *ctx)
257
-{
258
-	switch (ctx->conn_state) {
259
-	case CONN_STATE_WAIT_FOR_SYNC:
260
-		return tcp_conn_sync_complete(ctx);
261
-	case CONN_STATE_READ_OPCODE:
262
-		return tcp_conn_opcode_complete(ctx);
263
-	case CONN_STATE_READ_ARGS:
264
-		return tcp_conn_args_complete(ctx);
265
-	case CONN_STATE_READ_DATA:
266
-		return tcp_conn_data_complete(ctx);
267
-	default:
268
-		return -1;
269
-	}
270
-}
271
-
272
-static int tcp_conn_tx_complete(struct tcp_comm_ctx *ctx)
273
-{
274
-	switch (ctx->conn_state) {
275
-	case CONN_STATE_WRITE_RESP:
276
-		return tcp_conn_response_complete(ctx);
277
-	case CONN_STATE_WRITE_ERROR:
278
-		return -1;
279
-	default:
280
-		return -1;
281
-	}
282
-}
283
-
284
-static err_t tcp_conn_close(struct tcp_comm_ctx *ctx)
285
-{
286
-	err_t err = ERR_OK;
287
-
288
-	cyw43_arch_gpio_put (0, false);
289
-	ctx->conn_state = CONN_STATE_CLOSED;
290
-
291
-	if (!ctx->conn_pcb) {
292
-		return err;
293
-	}
294
-
295
-	tcp_arg(ctx->conn_pcb, NULL);
296
-	tcp_poll(ctx->conn_pcb, NULL, 0);
297
-	tcp_sent(ctx->conn_pcb, NULL);
298
-	tcp_recv(ctx->conn_pcb, NULL);
299
-	tcp_err(ctx->conn_pcb, NULL);
300
-	err = tcp_close(ctx->conn_pcb);
301
-	if (err != ERR_OK) {
302
-		DEBUG_printf("close failed %d, calling abort\n", err);
303
-		tcp_abort(ctx->conn_pcb);
304
-		err = ERR_ABRT;
305
-	}
306
-
307
-	ctx->conn_pcb = NULL;
308
-
309
-	return err;
310
-}
311
-
312
-static err_t tcp_server_close(void *arg)
313
-{
314
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
315
-	err_t err = ERR_OK;
316
-
317
-	err = tcp_conn_close(ctx);
318
-	if ((err != ERR_OK) && ctx->serv_pcb) {
319
-		tcp_arg(ctx->serv_pcb, NULL);
320
-		tcp_abort(ctx->serv_pcb);
321
-		ctx->serv_pcb = NULL;
322
-		return ERR_ABRT;
323
-	}
324
-
325
-	if (!ctx->serv_pcb) {
326
-		return err;
327
-	}
328
-
329
-	tcp_arg(ctx->serv_pcb, NULL);
330
-	err = tcp_close(ctx->serv_pcb);
331
-	if (err != ERR_OK) {
332
-		tcp_abort(ctx->serv_pcb);
333
-		err = ERR_ABRT;
334
-	}
335
-	ctx->serv_pcb = NULL;
336
-
337
-	return err;
338
-}
339
-
340
-static void tcp_server_complete(void *arg, int status)
341
-{
342
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
343
-	if (status == 0) {
344
-		DEBUG_printf("server completed normally\n");
345
-	} else {
346
-		DEBUG_printf("server error %d\n", status);
347
-	}
348
-
349
-	tcp_server_close(ctx);
350
-	ctx->serv_done = true;
351
-}
352
-
353
-static err_t tcp_conn_complete(void *arg, int status)
354
-{
355
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
356
-	if (status == 0) {
357
-		DEBUG_printf("conn completed normally\n");
358
-	} else {
359
-		DEBUG_printf("conn error %d\n", status);
360
-	}
361
-	return tcp_conn_close(ctx);
362
-}
363
-
364
-static err_t tcp_conn_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
365
-{
366
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
367
-	DEBUG_printf("tcp_server_sent %u\n", len);
368
-
369
-	cyw43_arch_lwip_check();
370
-	if (len > ctx->tx_bytes_remaining) {
371
-		DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
372
-		return tcp_conn_complete(ctx, ERR_ARG);
373
-	}
374
-
375
-	ctx->tx_bytes_remaining -= len;
376
-	ctx->tx_bytes_sent += len;
377
-
378
-	if (ctx->tx_bytes_remaining == 0) {
379
-		int res = tcp_conn_tx_complete(ctx);
380
-		if (res) {
381
-			return tcp_conn_complete(ctx, ERR_ARG);
382
-		}
383
-	}
384
-
385
-	return ERR_OK;
386
-}
387
-
388
-static err_t tcp_conn_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
389
-{
390
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
391
-	if (!p) {
392
-		DEBUG_printf("no pbuf\n");
393
-		return tcp_conn_complete(ctx, 0);
394
-	}
395
-
396
-	// this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
397
-	// can use this method to cause an assertion in debug mode, if this method is called when
398
-	// cyw43_arch_lwip_begin IS needed
399
-	cyw43_arch_lwip_check();
400
-	if (p->tot_len > 0) {
401
-		DEBUG_printf("tcp_server_recv %d err %d\n", p->tot_len, err);
402
-
403
-		size_t to_copy = p->tot_len > ctx->rx_bytes_remaining ? ctx->rx_bytes_remaining : p->tot_len;
404
-
405
-		// Receive the buffer
406
-		if (pbuf_copy_partial(p, ctx->buf + ctx->rx_bytes_received, to_copy, 0) != to_copy) {
407
-			DEBUG_printf("wrong copy len\n");
408
-			return tcp_conn_complete(ctx, ERR_ARG);
409
-		}
410
-
411
-		ctx->rx_bytes_received += to_copy;
412
-		ctx->rx_bytes_remaining -= to_copy;
413
-		tcp_recved(tpcb, p->tot_len);
414
-
415
-		if (ctx->rx_bytes_remaining == 0) {
416
-			int res = tcp_conn_rx_complete(ctx);
417
-			if (res) {
418
-				return tcp_conn_complete(ctx, ERR_ARG);
419
-			}
420
-		}
421
-	}
422
-	pbuf_free(p);
423
-
424
-	return ERR_OK;
425
-}
426
-
427
-static err_t tcp_conn_poll(void *arg, struct tcp_pcb *tpcb)
428
-{
429
-	DEBUG_printf("tcp_server_poll_fn\n");
430
-	return ERR_OK;
431
-}
432
-
433
-static void tcp_conn_err(void *arg, err_t err)
434
-{
435
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
436
-
437
-	DEBUG_printf("tcp_conn_err %d\n", err);
438
-
439
-	ctx->conn_pcb = NULL;
440
-	ctx->conn_state = CONN_STATE_CLOSED;
441
-	ctx->rx_bytes_remaining = 0;
442
-	cyw43_arch_gpio_put (0, false);
443
-}
444
-
445
-static void tcp_conn_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
446
-{
447
-	ctx->conn_pcb = pcb;
448
-	tcp_arg(pcb, ctx);
449
-
450
-	cyw43_arch_gpio_put (0, true);
451
-
452
-	tcp_conn_sync_begin(ctx);
453
-
454
-	tcp_sent(pcb, tcp_conn_sent);
455
-	tcp_recv(pcb, tcp_conn_recv);
456
-	tcp_poll(pcb, tcp_conn_poll, POLL_TIME_S * 2);
457
-	tcp_err(pcb, tcp_conn_err);
458
-}
459
-
460
-static err_t tcp_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
461
-{
462
-	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
463
-
464
-	if (err != ERR_OK || client_pcb == NULL) {
465
-		DEBUG_printf("Failure in accept\n");
466
-		tcp_server_complete(ctx, err);
467
-		return ERR_VAL;
468
-	}
469
-	DEBUG_printf("Connection opened\n");
470
-
471
-	if (ctx->conn_pcb) {
472
-		DEBUG_printf("Already have a connection\n");
473
-		tcp_abort(client_pcb);
474
-		return ERR_ABRT;
475
-	}
476
-
477
-	tcp_conn_init(ctx, client_pcb);
478
-
479
-	return ERR_OK;
480
-}
481
-
482
-static err_t tcp_server_listen(struct tcp_comm_ctx *ctx)
483
-{
484
-	DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), TCP_PORT);
485
-
486
-	ctx->serv_done = false;
487
-
488
-	struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
489
-	if (!pcb) {
490
-		DEBUG_printf("failed to create pcb\n");
491
-		return ERR_MEM;
492
-	}
493
-
494
-	err_t err = tcp_bind(pcb, NULL, TCP_PORT);
495
-	if (err) {
496
-		DEBUG_printf("failed to bind to port %d\n", TCP_PORT);
497
-		tcp_abort(pcb);
498
-		return err;
499
-	}
500
-
501
-	ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
502
-	if (!ctx->serv_pcb) {
503
-		DEBUG_printf("failed to listen: %d\n", err);
504
-		return err;
505
-	}
506
-
507
-	tcp_arg(ctx->serv_pcb, ctx);
508
-	tcp_accept(ctx->serv_pcb, tcp_server_accept);
509
-
510
-	return ERR_OK;
511
-}
512
-
513 28
 static uint32_t handle_sync(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_args_out, uint8_t *resp_data_out)
514 29
 {
515 30
 	return RSP_SYNC;
@@ -523,21 +38,6 @@ const struct comm_command util_sync_cmd = {
523 38
 	.handle = &handle_sync,
524 39
 };
525 40
 
526
-static void tcp_comm_init(struct tcp_comm_ctx *ctx, const struct comm_command *const *cmds,
527
-		unsigned int n_cmds, uint32_t sync_opcode)
528
-{
529
-	unsigned int i;
530
-	for (i = 0; i < n_cmds; i++) {
531
-		assert(cmds[i]->nargs <= MAX_NARG);
532
-		assert(cmds[i]->resp_nargs <= MAX_NARG);
533
-	}
534
-
535
-	memset(ctx, 0, sizeof(*ctx));
536
-	ctx->cmds = cmds;
537
-	ctx->n_cmds = n_cmds;
538
-	ctx->sync_opcode = sync_opcode;
539
-}
540
-
541 41
 int main()
542 42
 {
543 43
 	stdio_init_all();
@@ -559,22 +59,21 @@ int main()
559 59
 		printf("Connected.\n");
560 60
 	}
561 61
 
562
-	struct tcp_comm_ctx tcp;
563 62
 	const struct comm_command *cmds[] = {
564 63
 		&util_sync_cmd,
565 64
 	};
566 65
 
567
-	tcp_comm_init(&tcp, cmds, 1, CMD_SYNC);
66
+	struct tcp_comm_ctx *tcp = tcp_comm_new(cmds, 1, CMD_SYNC);
568 67
 
569 68
 	for ( ; ; ) {
570
-		err_t err = tcp_server_listen(&tcp);
69
+		err_t err = tcp_comm_listen(tcp, TCP_PORT);
571 70
 		if (err != ERR_OK) {
572 71
 			printf("Failed to start server: %d\n", err);
573 72
 			sleep_ms(1000);
574 73
 			continue;
575 74
 		}
576 75
 
577
-		while (!tcp.serv_done) {
76
+		while (!tcp_comm_server_done(tcp)) {
578 77
 			cyw43_arch_poll();
579 78
 			sleep_ms(10);
580 79
 		}

+ 525
- 0
tcp_comm.c View File

@@ -0,0 +1,525 @@
1
+/**
2
+ * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
3
+ *
4
+ * Parts based on the Pico W tcp_server example:
5
+ * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
6
+ *
7
+ * SPDX-License-Identifier: BSD-3-Clause
8
+ */
9
+#include <stdlib.h>
10
+
11
+#include "pico/cyw43_arch.h"
12
+
13
+#include "lwip/pbuf.h"
14
+#include "lwip/tcp.h"
15
+
16
+#include "tcp_comm.h"
17
+
18
+#define DEBUG_printf printf
19
+#define POLL_TIME_S 5
20
+
21
+#define COMM_MAX_NARG     5
22
+#define COMM_MAX_DATA_LEN 1024
23
+
24
+#define COMM_RSP_OK       (('O' << 0) | ('K' << 8) | ('O' << 16) | ('K' << 24))
25
+#define COMM_RSP_ERR      (('E' << 0) | ('R' << 8) | ('R' << 16) | ('!' << 24))
26
+
27
+enum conn_state {
28
+	CONN_STATE_WAIT_FOR_SYNC,
29
+	CONN_STATE_READ_OPCODE,
30
+	CONN_STATE_READ_ARGS,
31
+	CONN_STATE_READ_DATA,
32
+	CONN_STATE_HANDLE,
33
+	CONN_STATE_WRITE_RESP,
34
+	CONN_STATE_WRITE_ERROR,
35
+	CONN_STATE_CLOSED,
36
+};
37
+
38
+struct tcp_comm_ctx {
39
+	struct tcp_pcb *serv_pcb;
40
+	volatile bool serv_done;
41
+	enum conn_state conn_state;
42
+
43
+	struct tcp_pcb *client_pcb;
44
+	uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + COMM_MAX_DATA_LEN];
45
+	uint16_t rx_bytes_received;
46
+	uint16_t rx_bytes_remaining;
47
+
48
+	uint16_t tx_bytes_sent;
49
+	uint16_t tx_bytes_remaining;
50
+
51
+	uint32_t resp_data_len;
52
+
53
+	const struct comm_command *cmd;
54
+	const struct comm_command *const *cmds;
55
+	unsigned int n_cmds;
56
+	uint32_t sync_opcode;
57
+};
58
+
59
+#define COMM_BUF_OPCODE(_buf)       ((uint32_t *)((uint8_t *)(_buf)))
60
+#define COMM_BUF_ARGS(_buf)         ((uint32_t *)((uint8_t *)(_buf) + sizeof(uint32_t)))
61
+#define COMM_BUF_BODY(_buf, _nargs) ((uint8_t *)(_buf) + (sizeof(uint32_t) * ((_nargs) + 1)))
62
+
63
+static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
64
+{
65
+	unsigned int i;
66
+
67
+	for (i = 0; i < ctx->n_cmds; i++) {
68
+		if (ctx->cmds[i]->opcode == opcode) {
69
+			return ctx->cmds[i];
70
+		}
71
+	}
72
+
73
+	return NULL;
74
+}
75
+
76
+static bool is_error(uint32_t status)
77
+{
78
+	return status == COMM_RSP_ERR;
79
+}
80
+
81
+static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx);
82
+static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx);
83
+static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx);
84
+static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx);
85
+static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx);
86
+static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx);
87
+static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
88
+static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx);
89
+static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx);
90
+static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx);
91
+static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx);
92
+
93
+static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx)
94
+{
95
+	ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
96
+	ctx->rx_bytes_received = 0;
97
+	ctx->rx_bytes_remaining = sizeof(uint32_t);
98
+
99
+	DEBUG_printf("sync_begin %d\n", ctx->rx_bytes_remaining);
100
+}
101
+
102
+static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx)
103
+{
104
+	if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
105
+		DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
106
+		return tcp_comm_error_begin(ctx);
107
+	}
108
+
109
+	return tcp_comm_opcode_complete(ctx);
110
+}
111
+
112
+static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx)
113
+{
114
+	ctx->conn_state = CONN_STATE_READ_OPCODE;
115
+	ctx->rx_bytes_received = 0;
116
+	ctx->rx_bytes_remaining = sizeof(uint32_t);
117
+
118
+	return 0;
119
+}
120
+
121
+static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx)
122
+{
123
+	ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
124
+	if (!ctx->cmd) {
125
+		DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
126
+		return tcp_comm_error_begin(ctx);
127
+	} else {
128
+		DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
129
+	}
130
+
131
+	return tcp_comm_args_begin(ctx);
132
+}
133
+
134
+static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx)
135
+{
136
+	ctx->conn_state = CONN_STATE_READ_ARGS;
137
+	ctx->rx_bytes_received = 0;
138
+	ctx->rx_bytes_remaining = ctx->cmd->nargs * sizeof(uint32_t);
139
+
140
+	if (ctx->cmd->nargs == 0) {
141
+		return tcp_comm_args_complete(ctx);
142
+	}
143
+
144
+	return 0;
145
+}
146
+
147
+static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx)
148
+{
149
+	const struct comm_command *cmd = ctx->cmd;
150
+
151
+	uint32_t data_len = 0;
152
+
153
+	if (cmd->size) {
154
+		uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
155
+					    &data_len,
156
+					    &ctx->resp_data_len);
157
+		if (is_error(status)) {
158
+			return tcp_comm_error_begin(ctx);
159
+		}
160
+	}
161
+
162
+	return tcp_comm_data_begin(ctx, data_len);
163
+}
164
+
165
+static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
166
+{
167
+	const struct comm_command *cmd = ctx->cmd;
168
+
169
+	ctx->conn_state = CONN_STATE_READ_DATA;
170
+	ctx->rx_bytes_received = 0;
171
+	ctx->rx_bytes_remaining = data_len;
172
+
173
+	if (data_len == 0) {
174
+		return tcp_comm_data_complete(ctx);
175
+	}
176
+
177
+
178
+	return 0;
179
+}
180
+
181
+static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx)
182
+{
183
+	const struct comm_command *cmd = ctx->cmd;
184
+
185
+	if (cmd->handle) {
186
+		uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
187
+					      COMM_BUF_BODY(ctx->buf, cmd->nargs),
188
+					      COMM_BUF_ARGS(ctx->buf),
189
+					      COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
190
+		if (is_error(status)) {
191
+			return tcp_comm_error_begin(ctx);
192
+		}
193
+
194
+		*COMM_BUF_OPCODE(ctx->buf) = status;
195
+	} else {
196
+		// TODO: Should we just assert(desc->handle)?
197
+		*COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_OK;
198
+	}
199
+
200
+	return tcp_comm_response_begin(ctx);
201
+}
202
+
203
+static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx)
204
+{
205
+	ctx->conn_state = CONN_STATE_WRITE_RESP;
206
+	ctx->tx_bytes_sent = 0;
207
+	ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
208
+
209
+	err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
210
+	if (err != ERR_OK) {
211
+		return -1;
212
+	}
213
+
214
+	return 0;
215
+}
216
+
217
+static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx)
218
+{
219
+	ctx->conn_state = CONN_STATE_WRITE_ERROR;
220
+	ctx->tx_bytes_sent = 0;
221
+	ctx->tx_bytes_remaining = sizeof(uint32_t);
222
+
223
+	*COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_ERR;
224
+
225
+	err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
226
+	if (err != ERR_OK) {
227
+		return -1;
228
+	}
229
+
230
+	return 0;
231
+}
232
+
233
+
234
+static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
235
+{
236
+	return tcp_comm_opcode_begin(ctx);
237
+}
238
+
239
+static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
240
+{
241
+	switch (ctx->conn_state) {
242
+	case CONN_STATE_WAIT_FOR_SYNC:
243
+		return tcp_comm_sync_complete(ctx);
244
+	case CONN_STATE_READ_OPCODE:
245
+		return tcp_comm_opcode_complete(ctx);
246
+	case CONN_STATE_READ_ARGS:
247
+		return tcp_comm_args_complete(ctx);
248
+	case CONN_STATE_READ_DATA:
249
+		return tcp_comm_data_complete(ctx);
250
+	default:
251
+		return -1;
252
+	}
253
+}
254
+
255
+static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
256
+{
257
+	switch (ctx->conn_state) {
258
+	case CONN_STATE_WRITE_RESP:
259
+		return tcp_comm_response_complete(ctx);
260
+	case CONN_STATE_WRITE_ERROR:
261
+		return -1;
262
+	default:
263
+		return -1;
264
+	}
265
+}
266
+
267
+static err_t tcp_comm_client_close(struct tcp_comm_ctx *ctx)
268
+{
269
+	err_t err = ERR_OK;
270
+
271
+	cyw43_arch_gpio_put (0, false);
272
+	ctx->conn_state = CONN_STATE_CLOSED;
273
+
274
+	if (!ctx->client_pcb) {
275
+		return err;
276
+	}
277
+
278
+	tcp_arg(ctx->client_pcb, NULL);
279
+	tcp_poll(ctx->client_pcb, NULL, 0);
280
+	tcp_sent(ctx->client_pcb, NULL);
281
+	tcp_recv(ctx->client_pcb, NULL);
282
+	tcp_err(ctx->client_pcb, NULL);
283
+	err = tcp_close(ctx->client_pcb);
284
+	if (err != ERR_OK) {
285
+		DEBUG_printf("close failed %d, calling abort\n", err);
286
+		tcp_abort(ctx->client_pcb);
287
+		err = ERR_ABRT;
288
+	}
289
+
290
+	ctx->client_pcb = NULL;
291
+
292
+	return err;
293
+}
294
+
295
+err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx)
296
+{
297
+	err_t err = ERR_OK;
298
+
299
+	err = tcp_comm_client_close(ctx);
300
+	if ((err != ERR_OK) && ctx->serv_pcb) {
301
+		tcp_arg(ctx->serv_pcb, NULL);
302
+		tcp_abort(ctx->serv_pcb);
303
+		ctx->serv_pcb = NULL;
304
+		return ERR_ABRT;
305
+	}
306
+
307
+	if (!ctx->serv_pcb) {
308
+		return err;
309
+	}
310
+
311
+	tcp_arg(ctx->serv_pcb, NULL);
312
+	err = tcp_close(ctx->serv_pcb);
313
+	if (err != ERR_OK) {
314
+		tcp_abort(ctx->serv_pcb);
315
+		err = ERR_ABRT;
316
+	}
317
+	ctx->serv_pcb = NULL;
318
+
319
+	return err;
320
+}
321
+
322
+static void tcp_comm_server_complete(void *arg, int status)
323
+{
324
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
325
+	if (status == 0) {
326
+		DEBUG_printf("server completed normally\n");
327
+	} else {
328
+		DEBUG_printf("server error %d\n", status);
329
+	}
330
+
331
+	tcp_comm_server_close(ctx);
332
+	ctx->serv_done = true;
333
+}
334
+
335
+static err_t tcp_comm_client_complete(void *arg, int status)
336
+{
337
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
338
+	if (status == 0) {
339
+		DEBUG_printf("conn completed normally\n");
340
+	} else {
341
+		DEBUG_printf("conn error %d\n", status);
342
+	}
343
+	return tcp_comm_client_close(ctx);
344
+}
345
+
346
+static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
347
+{
348
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
349
+	DEBUG_printf("tcp_comm_server_sent %u\n", len);
350
+
351
+	cyw43_arch_lwip_check();
352
+	if (len > ctx->tx_bytes_remaining) {
353
+		DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
354
+		return tcp_comm_client_complete(ctx, ERR_ARG);
355
+	}
356
+
357
+	ctx->tx_bytes_remaining -= len;
358
+	ctx->tx_bytes_sent += len;
359
+
360
+	if (ctx->tx_bytes_remaining == 0) {
361
+		int res = tcp_comm_tx_complete(ctx);
362
+		if (res) {
363
+			return tcp_comm_client_complete(ctx, ERR_ARG);
364
+		}
365
+	}
366
+
367
+	return ERR_OK;
368
+}
369
+
370
+static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
371
+{
372
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
373
+	if (!p) {
374
+		DEBUG_printf("no pbuf\n");
375
+		return tcp_comm_client_complete(ctx, 0);
376
+	}
377
+
378
+	// this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
379
+	// can use this method to cause an assertion in debug mode, if this method is called when
380
+	// cyw43_arch_lwip_begin IS needed
381
+	cyw43_arch_lwip_check();
382
+	if (p->tot_len > 0) {
383
+		DEBUG_printf("tcp_comm_server_recv %d err %d\n", p->tot_len, err);
384
+
385
+		size_t to_copy = p->tot_len > ctx->rx_bytes_remaining ? ctx->rx_bytes_remaining : p->tot_len;
386
+
387
+		// Receive the buffer
388
+		if (pbuf_copy_partial(p, ctx->buf + ctx->rx_bytes_received, to_copy, 0) != to_copy) {
389
+			DEBUG_printf("wrong copy len\n");
390
+			return tcp_comm_client_complete(ctx, ERR_ARG);
391
+		}
392
+
393
+		ctx->rx_bytes_received += to_copy;
394
+		ctx->rx_bytes_remaining -= to_copy;
395
+		tcp_recved(tpcb, p->tot_len);
396
+
397
+		if (ctx->rx_bytes_remaining == 0) {
398
+			int res = tcp_comm_rx_complete(ctx);
399
+			if (res) {
400
+				return tcp_comm_client_complete(ctx, ERR_ARG);
401
+			}
402
+		}
403
+	}
404
+	pbuf_free(p);
405
+
406
+	return ERR_OK;
407
+}
408
+
409
+static err_t tcp_comm_client_poll(void *arg, struct tcp_pcb *tpcb)
410
+{
411
+	DEBUG_printf("tcp_comm_server_poll_fn\n");
412
+	return ERR_OK;
413
+}
414
+
415
+static void tcp_comm_client_err(void *arg, err_t err)
416
+{
417
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
418
+
419
+	DEBUG_printf("tcp_comm_err %d\n", err);
420
+
421
+	ctx->client_pcb = NULL;
422
+	ctx->conn_state = CONN_STATE_CLOSED;
423
+	ctx->rx_bytes_remaining = 0;
424
+	cyw43_arch_gpio_put (0, false);
425
+}
426
+
427
+static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
428
+{
429
+	ctx->client_pcb = pcb;
430
+	tcp_arg(pcb, ctx);
431
+
432
+	cyw43_arch_gpio_put (0, true);
433
+
434
+	tcp_comm_sync_begin(ctx);
435
+
436
+	tcp_sent(pcb, tcp_comm_client_sent);
437
+	tcp_recv(pcb, tcp_comm_client_recv);
438
+	tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
439
+	tcp_err(pcb, tcp_comm_client_err);
440
+}
441
+
442
+static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
443
+{
444
+	struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
445
+
446
+	if (err != ERR_OK || client_pcb == NULL) {
447
+		DEBUG_printf("Failure in accept\n");
448
+		tcp_comm_server_complete(ctx, err);
449
+		return ERR_VAL;
450
+	}
451
+	DEBUG_printf("Connection opened\n");
452
+
453
+	if (ctx->client_pcb) {
454
+		DEBUG_printf("Already have a connection\n");
455
+		tcp_abort(client_pcb);
456
+		return ERR_ABRT;
457
+	}
458
+
459
+	tcp_comm_client_init(ctx, client_pcb);
460
+
461
+	return ERR_OK;
462
+}
463
+
464
+err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port)
465
+{
466
+	DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), port);
467
+
468
+	ctx->serv_done = false;
469
+
470
+	struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
471
+	if (!pcb) {
472
+		DEBUG_printf("failed to create pcb\n");
473
+		return ERR_MEM;
474
+	}
475
+
476
+	err_t err = tcp_bind(pcb, NULL, port);
477
+	if (err) {
478
+		DEBUG_printf("failed to bind to port %d\n", port);
479
+		tcp_abort(pcb);
480
+		return err;
481
+	}
482
+
483
+	ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
484
+	if (!ctx->serv_pcb) {
485
+		DEBUG_printf("failed to listen: %d\n", err);
486
+		return err;
487
+	}
488
+
489
+	tcp_arg(ctx->serv_pcb, ctx);
490
+	tcp_accept(ctx->serv_pcb, tcp_comm_server_accept);
491
+
492
+	return ERR_OK;
493
+}
494
+
495
+struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
496
+		unsigned int n_cmds, uint32_t sync_opcode)
497
+{
498
+	struct tcp_comm_ctx *ctx = calloc(1, sizeof(struct tcp_comm_ctx));
499
+	if (!ctx) {
500
+		return NULL;
501
+	}
502
+
503
+	unsigned int i;
504
+	for (i = 0; i < n_cmds; i++) {
505
+		assert(cmds[i]->nargs <= MAX_NARG);
506
+		assert(cmds[i]->resp_nargs <= MAX_NARG);
507
+	}
508
+
509
+	ctx->cmds = cmds;
510
+	ctx->n_cmds = n_cmds;
511
+	ctx->sync_opcode = sync_opcode;
512
+
513
+	return ctx;
514
+}
515
+
516
+void tcp_comm_delete(struct tcp_comm_ctx *ctx)
517
+{
518
+	tcp_comm_server_close(ctx);
519
+	free(ctx);
520
+}
521
+
522
+bool tcp_comm_server_done(struct tcp_comm_ctx *ctx)
523
+{
524
+	return ctx->serv_done;
525
+}

+ 30
- 0
tcp_comm.h View File

@@ -0,0 +1,30 @@
1
+/**
2
+ * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
3
+ *
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ */
6
+#ifndef __TCP_COMM_H__
7
+#define __TCP_COMM_H__
8
+
9
+#include <stdint.h>
10
+#include <stdbool.h>
11
+
12
+struct comm_command {
13
+	uint32_t opcode;
14
+	uint32_t nargs;
15
+	uint32_t resp_nargs;
16
+	uint32_t (*size)(uint32_t *args_in, uint32_t *data_len_out, uint32_t *resp_data_len_out);
17
+	uint32_t (*handle)(uint32_t *args_in, uint8_t *data_in, uint32_t *resp_args_out, uint8_t *resp_data_out);
18
+};
19
+
20
+struct tcp_comm_ctx;
21
+
22
+err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port);
23
+err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx);
24
+bool tcp_comm_server_done(struct tcp_comm_ctx *ctx);
25
+
26
+struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
27
+		unsigned int n_cmds, uint32_t sync_opcode);
28
+void tcp_comm_delete(struct tcp_comm_ctx *ctx);
29
+
30
+#endif /* __TCP_COMM_H__ */

Loading…
Cancel
Save