|
@@ -0,0 +1,434 @@
|
|
1
|
+#
|
|
2
|
+# MarlinBinaryProtocol.py
|
|
3
|
+# Supporting Firmware upload via USB/Serial, saving to the attached media.
|
|
4
|
+#
|
|
5
|
+import serial
|
|
6
|
+import math
|
|
7
|
+import time
|
|
8
|
+from collections import deque
|
|
9
|
+import threading
|
|
10
|
+import sys
|
|
11
|
+import datetime
|
|
12
|
+import random
|
|
13
|
+try:
|
|
14
|
+ import heatshrink
|
|
15
|
+ heatshrink_exists = True
|
|
16
|
+except ImportError:
|
|
17
|
+ heatshrink_exists = False
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+def millis():
|
|
21
|
+ return time.perf_counter() * 1000
|
|
22
|
+
|
|
23
|
+class TimeOut(object):
|
|
24
|
+ def __init__(self, milliseconds):
|
|
25
|
+ self.duration = milliseconds
|
|
26
|
+ self.reset()
|
|
27
|
+
|
|
28
|
+ def reset(self):
|
|
29
|
+ self.endtime = millis() + self.duration
|
|
30
|
+
|
|
31
|
+ def timedout(self):
|
|
32
|
+ return millis() > self.endtime
|
|
33
|
+
|
|
34
|
+class ReadTimeout(Exception):
|
|
35
|
+ pass
|
|
36
|
+class FatalError(Exception):
|
|
37
|
+ pass
|
|
38
|
+class SycronisationError(Exception):
|
|
39
|
+ pass
|
|
40
|
+class PayloadOverflow(Exception):
|
|
41
|
+ pass
|
|
42
|
+class ConnectionLost(Exception):
|
|
43
|
+ pass
|
|
44
|
+
|
|
45
|
+class Protocol(object):
|
|
46
|
+ device = None
|
|
47
|
+ baud = None
|
|
48
|
+ max_block_size = 0
|
|
49
|
+ port = None
|
|
50
|
+ block_size = 0
|
|
51
|
+
|
|
52
|
+ packet_transit = None
|
|
53
|
+ packet_status = None
|
|
54
|
+ packet_ping = None
|
|
55
|
+
|
|
56
|
+ errors = 0
|
|
57
|
+ packet_buffer = None
|
|
58
|
+ simulate_errors = 0
|
|
59
|
+ sync = 0
|
|
60
|
+ connected = False
|
|
61
|
+ syncronised = False
|
|
62
|
+ worker_thread = None
|
|
63
|
+
|
|
64
|
+ response_timeout = 1000
|
|
65
|
+
|
|
66
|
+ applications = []
|
|
67
|
+ responses = deque()
|
|
68
|
+
|
|
69
|
+ def __init__(self, device, baud, bsize, simerr, timeout):
|
|
70
|
+ print("pySerial Version:", serial.VERSION)
|
|
71
|
+ self.port = serial.Serial(device, baudrate = baud, write_timeout = 0, timeout = 1)
|
|
72
|
+ self.device = device
|
|
73
|
+ self.baud = baud
|
|
74
|
+ self.block_size = int(bsize)
|
|
75
|
+ self.simulate_errors = max(min(simerr, 1.0), 0.0);
|
|
76
|
+ self.connected = True
|
|
77
|
+ self.response_timeout = timeout
|
|
78
|
+
|
|
79
|
+ self.register(['ok', 'rs', 'ss', 'fe'], self.process_input)
|
|
80
|
+
|
|
81
|
+ self.worker_thread = threading.Thread(target=Protocol.receive_worker, args=(self,))
|
|
82
|
+ self.worker_thread.start()
|
|
83
|
+
|
|
84
|
+ def receive_worker(self):
|
|
85
|
+ while self.port.in_waiting:
|
|
86
|
+ self.port.reset_input_buffer()
|
|
87
|
+
|
|
88
|
+ def dispatch(data):
|
|
89
|
+ for tokens, callback in self.applications:
|
|
90
|
+ for token in tokens:
|
|
91
|
+ if token == data[:len(token)]:
|
|
92
|
+ callback((token, data[len(token):]))
|
|
93
|
+ return
|
|
94
|
+
|
|
95
|
+ def reconnect():
|
|
96
|
+ print("Reconnecting..")
|
|
97
|
+ self.port.close()
|
|
98
|
+ for x in range(10):
|
|
99
|
+ try:
|
|
100
|
+ if self.connected:
|
|
101
|
+ self.port = serial.Serial(self.device, baudrate = self.baud, write_timeout = 0, timeout = 1)
|
|
102
|
+ return
|
|
103
|
+ else:
|
|
104
|
+ print("Connection closed")
|
|
105
|
+ return
|
|
106
|
+ except:
|
|
107
|
+ time.sleep(1)
|
|
108
|
+ raise ConnectionLost()
|
|
109
|
+
|
|
110
|
+ while self.connected:
|
|
111
|
+ try:
|
|
112
|
+ data = self.port.readline().decode('utf8').rstrip()
|
|
113
|
+ if len(data):
|
|
114
|
+ #print(data)
|
|
115
|
+ dispatch(data)
|
|
116
|
+ except OSError:
|
|
117
|
+ reconnect()
|
|
118
|
+ except UnicodeDecodeError:
|
|
119
|
+ # dodgy client output or datastream corruption
|
|
120
|
+ self.port.reset_input_buffer()
|
|
121
|
+
|
|
122
|
+ def shutdown(self):
|
|
123
|
+ self.connected = False
|
|
124
|
+ self.worker_thread.join()
|
|
125
|
+ self.port.close()
|
|
126
|
+
|
|
127
|
+ def process_input(self, data):
|
|
128
|
+ #print(data)
|
|
129
|
+ self.responses.append(data)
|
|
130
|
+
|
|
131
|
+ def register(self, tokens, callback):
|
|
132
|
+ self.applications.append((tokens, callback))
|
|
133
|
+
|
|
134
|
+ def send(self, protocol, packet_type, data = bytearray()):
|
|
135
|
+ self.packet_transit = self.build_packet(protocol, packet_type, data)
|
|
136
|
+ self.packet_status = 0
|
|
137
|
+ self.transmit_attempt = 0
|
|
138
|
+
|
|
139
|
+ timeout = TimeOut(self.response_timeout * 20)
|
|
140
|
+ while self.packet_status == 0:
|
|
141
|
+ try:
|
|
142
|
+ if timeout.timedout():
|
|
143
|
+ raise ConnectionLost()
|
|
144
|
+ self.transmit_packet(self.packet_transit)
|
|
145
|
+ self.await_response()
|
|
146
|
+ except ReadTimeout:
|
|
147
|
+ self.errors += 1
|
|
148
|
+ #print("Packetloss detected..")
|
|
149
|
+ self.packet_transit = None
|
|
150
|
+
|
|
151
|
+ def await_response(self):
|
|
152
|
+ timeout = TimeOut(self.response_timeout)
|
|
153
|
+ while not len(self.responses):
|
|
154
|
+ time.sleep(0.00001)
|
|
155
|
+ if timeout.timedout():
|
|
156
|
+ raise ReadTimeout()
|
|
157
|
+
|
|
158
|
+ while len(self.responses):
|
|
159
|
+ token, data = self.responses.popleft()
|
|
160
|
+ switch = {'ok' : self.response_ok, 'rs': self.response_resend, 'ss' : self.response_stream_sync, 'fe' : self.response_fatal_error}
|
|
161
|
+ switch[token](data)
|
|
162
|
+
|
|
163
|
+ def send_ascii(self, data, send_and_forget = False):
|
|
164
|
+ self.packet_transit = bytearray(data, "utf8") + b'\n'
|
|
165
|
+ self.packet_status = 0
|
|
166
|
+ self.transmit_attempt = 0
|
|
167
|
+
|
|
168
|
+ timeout = TimeOut(self.response_timeout * 20)
|
|
169
|
+ while self.packet_status == 0:
|
|
170
|
+ try:
|
|
171
|
+ if timeout.timedout():
|
|
172
|
+ return
|
|
173
|
+ self.port.write(self.packet_transit)
|
|
174
|
+ if send_and_forget:
|
|
175
|
+ self.packet_status = 1
|
|
176
|
+ else:
|
|
177
|
+ self.await_response_ascii()
|
|
178
|
+ except ReadTimeout:
|
|
179
|
+ self.errors += 1
|
|
180
|
+ #print("Packetloss detected..")
|
|
181
|
+ except serial.serialutil.SerialException:
|
|
182
|
+ return
|
|
183
|
+ self.packet_transit = None
|
|
184
|
+
|
|
185
|
+ def await_response_ascii(self):
|
|
186
|
+ timeout = TimeOut(self.response_timeout)
|
|
187
|
+ while not len(self.responses):
|
|
188
|
+ time.sleep(0.00001)
|
|
189
|
+ if timeout.timedout():
|
|
190
|
+ raise ReadTimeout()
|
|
191
|
+ token, data = self.responses.popleft()
|
|
192
|
+ self.packet_status = 1
|
|
193
|
+
|
|
194
|
+ def corrupt_array(self, data):
|
|
195
|
+ rid = random.randint(0, len(data) - 1)
|
|
196
|
+ data[rid] ^= 0xAA
|
|
197
|
+ return data
|
|
198
|
+
|
|
199
|
+ def transmit_packet(self, packet):
|
|
200
|
+ packet = bytearray(packet)
|
|
201
|
+ if(self.simulate_errors > 0 and random.random() > (1.0 - self.simulate_errors)):
|
|
202
|
+ if random.random() > 0.9:
|
|
203
|
+ #random data drop
|
|
204
|
+ start = random.randint(0, len(packet))
|
|
205
|
+ end = start + random.randint(1, 10)
|
|
206
|
+ packet = packet[:start] + packet[end:]
|
|
207
|
+ #print("Dropping {0} bytes".format(end - start))
|
|
208
|
+ else:
|
|
209
|
+ #random corruption
|
|
210
|
+ packet = self.corrupt_array(packet)
|
|
211
|
+ #print("Single byte corruption")
|
|
212
|
+ self.port.write(packet)
|
|
213
|
+ self.transmit_attempt += 1
|
|
214
|
+
|
|
215
|
+ def build_packet(self, protocol, packet_type, data = bytearray()):
|
|
216
|
+ PACKET_TOKEN = 0xB5AD
|
|
217
|
+
|
|
218
|
+ if len(data) > self.max_block_size:
|
|
219
|
+ raise PayloadOverflow()
|
|
220
|
+
|
|
221
|
+ packet_buffer = bytearray()
|
|
222
|
+
|
|
223
|
+ packet_buffer += self.pack_int8(self.sync) # 8bit sync id
|
|
224
|
+ packet_buffer += self.pack_int4_2(protocol, packet_type) # 4 bit protocol id, 4 bit packet type
|
|
225
|
+ packet_buffer += self.pack_int16(len(data)) # 16bit packet length
|
|
226
|
+ packet_buffer += self.pack_int16(self.build_checksum(packet_buffer)) # 16bit header checksum
|
|
227
|
+
|
|
228
|
+ if len(data):
|
|
229
|
+ packet_buffer += data
|
|
230
|
+ packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))
|
|
231
|
+
|
|
232
|
+ packet_buffer = self.pack_int16(PACKET_TOKEN) + packet_buffer # 16bit start token, not included in checksum
|
|
233
|
+ return packet_buffer
|
|
234
|
+
|
|
235
|
+ # checksum 16 fletchers
|
|
236
|
+ def checksum(self, cs, value):
|
|
237
|
+ cs_low = (((cs & 0xFF) + value) % 255);
|
|
238
|
+ return ((((cs >> 8) + cs_low) % 255) << 8) | cs_low;
|
|
239
|
+
|
|
240
|
+ def build_checksum(self, buffer):
|
|
241
|
+ cs = 0
|
|
242
|
+ for b in buffer:
|
|
243
|
+ cs = self.checksum(cs, b)
|
|
244
|
+ return cs
|
|
245
|
+
|
|
246
|
+ def pack_int32(self, value):
|
|
247
|
+ return value.to_bytes(4, byteorder='little')
|
|
248
|
+
|
|
249
|
+ def pack_int16(self, value):
|
|
250
|
+ return value.to_bytes(2, byteorder='little')
|
|
251
|
+
|
|
252
|
+ def pack_int8(self, value):
|
|
253
|
+ return value.to_bytes(1, byteorder='little')
|
|
254
|
+
|
|
255
|
+ def pack_int4_2(self, vh, vl):
|
|
256
|
+ value = ((vh & 0xF) << 4) | (vl & 0xF)
|
|
257
|
+ return value.to_bytes(1, byteorder='little')
|
|
258
|
+
|
|
259
|
+ def connect(self):
|
|
260
|
+ print("Connecting: Switching Marlin to Binary Protocol...")
|
|
261
|
+ self.send_ascii("M28B1")
|
|
262
|
+ self.send(0, 1)
|
|
263
|
+
|
|
264
|
+ def disconnect(self):
|
|
265
|
+ self.send(0, 2)
|
|
266
|
+ self.syncronised = False
|
|
267
|
+
|
|
268
|
+ def response_ok(self, data):
|
|
269
|
+ try:
|
|
270
|
+ packet_id = int(data);
|
|
271
|
+ except ValueError:
|
|
272
|
+ return
|
|
273
|
+ if packet_id != self.sync:
|
|
274
|
+ raise SycronisationError()
|
|
275
|
+ self.sync = (self.sync + 1) % 256
|
|
276
|
+ self.packet_status = 1
|
|
277
|
+
|
|
278
|
+ def response_resend(self, data):
|
|
279
|
+ packet_id = int(data);
|
|
280
|
+ self.errors += 1
|
|
281
|
+ if not self.syncronised:
|
|
282
|
+ print("Retrying syncronisation")
|
|
283
|
+ elif packet_id != self.sync:
|
|
284
|
+ raise SycronisationError()
|
|
285
|
+
|
|
286
|
+ def response_stream_sync(self, data):
|
|
287
|
+ sync, max_block_size, protocol_version = data.split(',')
|
|
288
|
+ self.sync = int(sync)
|
|
289
|
+ self.max_block_size = int(max_block_size)
|
|
290
|
+ self.block_size = self.max_block_size if self.max_block_size < self.block_size else self.block_size
|
|
291
|
+ self.protocol_version = protocol_version
|
|
292
|
+ self.packet_status = 1
|
|
293
|
+ self.syncronised = True
|
|
294
|
+ print("Connection synced [{0}], binary protocol version {1}, {2} byte payload buffer".format(self.sync, self.protocol_version, self.max_block_size))
|
|
295
|
+
|
|
296
|
+ def response_fatal_error(self, data):
|
|
297
|
+ raise FatalError()
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+class FileTransferProtocol(object):
|
|
301
|
+ protocol_id = 1
|
|
302
|
+
|
|
303
|
+ class Packet(object):
|
|
304
|
+ QUERY = 0
|
|
305
|
+ OPEN = 1
|
|
306
|
+ CLOSE = 2
|
|
307
|
+ WRITE = 3
|
|
308
|
+ ABORT = 4
|
|
309
|
+
|
|
310
|
+ responses = deque()
|
|
311
|
+ def __init__(self, protocol, timeout = None):
|
|
312
|
+ protocol.register(['PFT:success', 'PFT:version:', 'PFT:fail', 'PFT:busy', 'PFT:ioerror', 'PTF:invalid'], self.process_input)
|
|
313
|
+ self.protocol = protocol
|
|
314
|
+ self.response_timeout = timeout or protocol.response_timeout
|
|
315
|
+
|
|
316
|
+ def process_input(self, data):
|
|
317
|
+ #print(data)
|
|
318
|
+ self.responses.append(data)
|
|
319
|
+
|
|
320
|
+ def await_response(self, timeout = None):
|
|
321
|
+ timeout = TimeOut(timeout or self.response_timeout)
|
|
322
|
+ while not len(self.responses):
|
|
323
|
+ time.sleep(0.0001)
|
|
324
|
+ if timeout.timedout():
|
|
325
|
+ raise ReadTimeout()
|
|
326
|
+
|
|
327
|
+ return self.responses.popleft()
|
|
328
|
+
|
|
329
|
+ def connect(self):
|
|
330
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.QUERY);
|
|
331
|
+
|
|
332
|
+ token, data = self.await_response()
|
|
333
|
+ if token != 'PFT:version:':
|
|
334
|
+ return False
|
|
335
|
+
|
|
336
|
+ self.version, _, compression = data.split(':')
|
|
337
|
+ if compression != 'none':
|
|
338
|
+ algorithm, window, lookahead = compression.split(',')
|
|
339
|
+ self.compression = {'algorithm': algorithm, 'window': int(window), 'lookahead': int(lookahead)}
|
|
340
|
+ else:
|
|
341
|
+ self.compression = {'algorithm': 'none'}
|
|
342
|
+
|
|
343
|
+ print("File Transfer version: {0}, compression: {1}".format(self.version, self.compression['algorithm']))
|
|
344
|
+
|
|
345
|
+ def open(self, filename, compression, dummy):
|
|
346
|
+ payload = b'\1' if dummy else b'\0' # dummy transfer
|
|
347
|
+ payload += b'\1' if compression else b'\0' # payload compression
|
|
348
|
+ payload += bytearray(filename, 'utf8') + b'\0'# target filename + null terminator
|
|
349
|
+
|
|
350
|
+ timeout = TimeOut(5000)
|
|
351
|
+ token = None
|
|
352
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload);
|
|
353
|
+ while token != 'PFT:success' and not timeout.timedout():
|
|
354
|
+ try:
|
|
355
|
+ token, data = self.await_response(1000)
|
|
356
|
+ if token == 'PFT:success':
|
|
357
|
+ print(filename,"opened")
|
|
358
|
+ return
|
|
359
|
+ elif token == 'PFT:busy':
|
|
360
|
+ print("Broken transfer detected, purging")
|
|
361
|
+ self.abort()
|
|
362
|
+ time.sleep(0.1)
|
|
363
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload);
|
|
364
|
+ timeout.reset()
|
|
365
|
+ elif token == 'PFT:fail':
|
|
366
|
+ raise Exception("Can not open file on client")
|
|
367
|
+ except ReadTimeout:
|
|
368
|
+ pass
|
|
369
|
+ raise ReadTimeout()
|
|
370
|
+
|
|
371
|
+ def write(self, data):
|
|
372
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.WRITE, data);
|
|
373
|
+
|
|
374
|
+ def close(self):
|
|
375
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.CLOSE);
|
|
376
|
+ token, data = self.await_response(1000)
|
|
377
|
+ if token == 'PFT:success':
|
|
378
|
+ print("File closed")
|
|
379
|
+ return
|
|
380
|
+ elif token == 'PFT:ioerror':
|
|
381
|
+ print("Client storage device IO error")
|
|
382
|
+ elif token == 'PFT:invalid':
|
|
383
|
+ print("No open file")
|
|
384
|
+
|
|
385
|
+ def abort(self):
|
|
386
|
+ self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.ABORT);
|
|
387
|
+ token, data = self.await_response()
|
|
388
|
+ if token == 'PFT:success':
|
|
389
|
+ print("Transfer Aborted")
|
|
390
|
+
|
|
391
|
+ def copy(self, filename, dest_filename, compression, dummy):
|
|
392
|
+ self.connect()
|
|
393
|
+
|
|
394
|
+ compression_support = heatshrink_exists and self.compression['algorithm'] == 'heatshrink' and compression
|
|
395
|
+ if compression and (not heatshrink_exists or not self.compression['algorithm'] == 'heatshrink'):
|
|
396
|
+ print("Compression not supported by client")
|
|
397
|
+ #compression_support = False
|
|
398
|
+
|
|
399
|
+ data = open(filename, "rb").read()
|
|
400
|
+ filesize = len(data)
|
|
401
|
+
|
|
402
|
+ self.open(dest_filename, compression_support, dummy)
|
|
403
|
+
|
|
404
|
+ block_size = self.protocol.block_size
|
|
405
|
+ if compression_support:
|
|
406
|
+ data = heatshrink.encode(data, window_sz2=self.compression['window'], lookahead_sz2=self.compression['lookahead'])
|
|
407
|
+
|
|
408
|
+ cratio = filesize / len(data)
|
|
409
|
+
|
|
410
|
+ blocks = math.floor((len(data) + block_size - 1) / block_size)
|
|
411
|
+ kibs = 0
|
|
412
|
+ dump_pctg = 0
|
|
413
|
+ start_time = millis()
|
|
414
|
+ for i in range(blocks):
|
|
415
|
+ start = block_size * i
|
|
416
|
+ end = start + block_size
|
|
417
|
+ self.write(data[start:end])
|
|
418
|
+ kibs = (( (i+1) * block_size) / 1024) / (millis() + 1 - start_time) * 1000
|
|
419
|
+ if (i / blocks) >= dump_pctg:
|
|
420
|
+ print("\r{0:2.2f}% {1:4.2f}KiB/s {2} Errors: {3}".format((i / blocks) * 100, kibs, "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression_support else "", self.protocol.errors), end='')
|
|
421
|
+ dump_pctg += 0.1
|
|
422
|
+ print("\r{0:2.2f}% {1:4.2f}KiB/s {2} Errors: {3}".format(100, kibs, "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression_support else "", self.protocol.errors)) # no one likes transfers finishing at 99.8%
|
|
423
|
+
|
|
424
|
+ self.close()
|
|
425
|
+ print("Transfer complete")
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+class EchoProtocol(object):
|
|
429
|
+ def __init__(self, protocol):
|
|
430
|
+ protocol.register(['echo:'], self.process_input)
|
|
431
|
+ self.protocol = protocol
|
|
432
|
+
|
|
433
|
+ def process_input(self, data):
|
|
434
|
+ print(data)
|