v2 / vlib / compress / lz / interop / lz77_ref.py
141 lines · 122 sloc · 3.95 KB · 3c1a78c175a6f8e3cfc26d862ad81d2309f9a507
Raw
1#!/usr/bin/env python3
2import sys
3import time
4
# Shortest back-reference length: a match control byte stores
# (length - MIN_MATCH) in its low 7 bits, so matches span 3..130 bytes.
MIN_MATCH: int = 3
# Longest literal run per token: a literal control byte stores (run - 1)
# in its low 7 bits with the high bit clear, so runs span 1..128 bytes.
MAX_LITERAL: int = 128
# Four-byte signature that opens every stream.
STREAM_MAGIC: bytes = b'VLZ1'
# Format tag byte that follows the magic; the decoder rejects other values.
FORMAT_LZ77: int = 0
9
10
def encode_uvarint(value: int) -> bytes:
    """Encode a non-negative integer as an unsigned base-128 (LEB128) varint.

    Seven bits are emitted per byte, least-significant group first. Every
    byte except the final one has its high bit (0x80) set to mark that more
    bytes follow. A negative *value* is rejected by ``bytearray.append``
    with ``ValueError``.
    """
    buf = bytearray()
    remaining = value
    while remaining > 0x7F:
        # Low 7 bits plus the continuation flag.
        buf.append(0x80 | (remaining & 0x7F))
        remaining >>= 7
    # Final group fits in 7 bits, so its high bit is clear.
    buf.append(remaining)
    return bytes(buf)
19
20
def decode_uvarint(data: bytes, pos: int):
    """Decode an unsigned base-128 (LEB128) varint starting at ``data[pos]``.

    Returns a ``(value, next_pos)`` tuple, where ``next_pos`` is the index of
    the first byte after the varint.

    Raises ``ValueError('bad length varint')`` in two cases: *data* ends
    before a byte with the high bit clear, or ten bytes (shifts 0..63) are
    consumed without finding a terminating byte.

    NOTE: the same message is raised when this decodes match offsets, not
    only the stream length.
    """
    result = 0
    cursor = pos
    for shift in range(0, 64, 7):  # 0, 7, ..., 63: at most ten bytes
        if cursor >= len(data):
            break
        byte = data[cursor]
        cursor += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, cursor
    raise ValueError('bad length varint')
33
34
def compress_lz77(data: bytes) -> bytes:
    """Encode *data* as a VLZ1 stream built only from literal runs.

    Output layout:

    * ``STREAM_MAGIC`` (4 bytes);
    * the ``FORMAT_LZ77`` tag byte;
    * ``len(data)`` as a uvarint;
    * a sequence of literal tokens. Each token is a control byte holding
      ``run_length - 1``, followed by that many raw bytes, with at most
      ``MAX_LITERAL`` bytes per run.

    No back-references are emitted, so the result is never smaller than the
    input. It is a format-conformant reference stream that
    ``decompress_lz77`` can decode.
    """
    out = bytearray(STREAM_MAGIC)
    out.append(FORMAT_LZ77)
    out.extend(encode_uvarint(len(data)))
    for start in range(0, len(data), MAX_LITERAL):
        run = data[start:start + MAX_LITERAL]
        # Control byte: run length minus one (0x00..0x7F when MAX_LITERAL is 128).
        out.append(len(run) - 1)
        out.extend(run)
    return bytes(out)
47
48
def decompress_lz77(data: bytes) -> bytes:
    """Decode a VLZ1 stream such as the one produced by ``compress_lz77``.

    Stream layout: ``STREAM_MAGIC`` (4 bytes), one format byte that must
    equal ``FORMAT_LZ77``, the decoded length as a uvarint, then tokens up
    to the end of *data*. There are two token kinds:

    * Control byte with the high bit clear: a literal run of
      ``(control & 0x7F) + 1`` bytes, copied verbatim from the stream.
    * Control byte with the high bit set: a back-reference of
      ``(control & 0x7F) + MIN_MATCH`` bytes, followed by a uvarint distance
      ``off`` with ``1 <= off <= bytes decoded so far``. When ``off`` is
      less than the length, the copy overlaps and repeats the most recent
      ``off`` bytes.

    Raises:
        ValueError: on bad magic, wrong format byte, malformed varint,
            truncated literal, out-of-range offset, or decoded output that
            would exceed or does not equal the declared length.
    """
    if len(data) < 6 or data[:4] != STREAM_MAGIC:
        raise ValueError('bad magic')
    if data[4] != FORMAT_LZ77:
        raise ValueError('format mismatch')

    expected_len, pos = decode_uvarint(data, 5)
    end = len(data)
    out = bytearray()
    while pos < end:
        control = data[pos]
        pos += 1
        if not control & 0x80:
            literal_len = (control & 0x7F) + 1
            if pos + literal_len > end:
                raise ValueError('truncated literal')
            # Fail fast on overrun instead of decoding the rest of a corrupt stream.
            if len(out) + literal_len > expected_len:
                raise ValueError('length mismatch')
            out += data[pos:pos + literal_len]
            pos += literal_len
        else:
            length = (control & 0x7F) + MIN_MATCH
            off, pos = decode_uvarint(data, pos)
            if off == 0 or off > len(out):
                raise ValueError('bad offset')
            if len(out) + length > expected_len:
                raise ValueError('length mismatch')
            base = len(out) - off
            if off >= length:
                # Source lies entirely in already-decoded output: one slice copy.
                out += out[base:base + length]
            else:
                # Overlapping copy: bytes appended by this match are re-read,
                # so it must proceed one byte at a time.
                for k in range(length):
                    out.append(out[base + k])
    if len(out) != expected_len:
        raise ValueError('length mismatch')
    return bytes(out)
77
78
def _bench(input_path: str, iterations: int) -> int:
    """Round-trip the file at *input_path* through the codec *iterations* times.

    Reading the file is not timed. On success, prints
    ``ms=<elapsed wall-clock milliseconds>`` to stdout and returns 0. If any
    decoded result differs from the input, prints ``roundtrip mismatch`` to
    stderr and returns 1.
    """
    with open(input_path, 'rb') as f:
        data = f.read()
    start = time.perf_counter()
    for _ in range(iterations):
        enc = compress_lz77(data)
        dec = decompress_lz77(enc)
        if dec != data:
            print('roundtrip mismatch', file=sys.stderr)
            return 1
    elapsed_ms = int((time.perf_counter() - start) * 1000)
    print(f'ms={elapsed_ms}')
    return 0


def _transcode(mode: str, input_path: str, output_path: str) -> int:
    """Compress (``mode == 'compress'``) or decompress a file into *output_path*.

    The output file is opened only after the transform succeeds, so a
    malformed input stream never truncates an existing output file.
    """
    with open(input_path, 'rb') as f:
        data = f.read()
    if mode == 'compress':
        out = compress_lz77(data)
    else:
        out = decompress_lz77(data)
    with open(output_path, 'wb') as f:
        f.write(out)
    return 0


def main() -> int:
    """Command-line entry point; returns the process exit status.

    Usage::

        lz77_ref.py bench <input.bin> <iterations>
        lz77_ref.py compress <input.bin> <output.bin>
        lz77_ref.py decompress <input.bin> <output.bin>

    Returns 0 on success. Returns 1 on any of the following, reported as a
    single stderr line rather than a traceback:

    * a usage error or unknown mode;
    * a non-integer or non-positive iteration count;
    * an I/O failure;
    * a malformed stream;
    * a bench round-trip mismatch.
    """
    argv = sys.argv
    prog = argv[0]
    if len(argv) < 2:
        print(
            f'usage:\n'
            f' {prog} bench <input.bin> <iterations>\n'
            f' {prog} compress <input.bin> <output.bin>\n'
            f' {prog} decompress <input.bin> <output.bin>',
            file=sys.stderr,
        )
        return 1

    mode = argv[1]
    if mode == 'bench':
        operands = '<input.bin> <iterations>'
    elif mode in ('compress', 'decompress'):
        operands = '<input.bin> <output.bin>'
    else:
        print(f'unknown mode: {mode}', file=sys.stderr)
        return 1
    if len(argv) < 4:
        print(f'usage: {prog} {mode} {operands}', file=sys.stderr)
        return 1

    iterations = 0
    if mode == 'bench':
        # Validate before touching the filesystem; int() raises on junk like 'abc'.
        try:
            iterations = int(argv[3])
        except ValueError:
            print(f'iterations must be an integer, got {argv[3]!r}', file=sys.stderr)
            return 1
        if iterations <= 0:
            print('iterations must be > 0', file=sys.stderr)
            return 1

    try:
        if mode == 'bench':
            return _bench(argv[2], iterations)
        return _transcode(mode, argv[2], argv[3])
    except OSError as err:
        # Missing/unreadable input or unwritable output.
        print(f'{mode}: {err}', file=sys.stderr)
        return 1
    except ValueError as err:
        # Raised by the decoder for malformed streams (bad magic, bad offset, ...).
        print(f'{mode}: invalid stream: {err}', file=sys.stderr)
        return 1
137
138
# Script entry: exit with main()'s status (sys.exit raises SystemExit).
if __name__ == '__main__':
    sys.exit(main())
141
142