Commit d687c276 authored by Julien Muchembled

Add support for Zstd algorithm

parent 9f0f2afe
@@ -17,7 +17,7 @@
     <key name="compress" datatype=".compress">
       <description>
         The value is either of 'boolean' type or an explicit algorithm that
-        matches the regex 'zlib(=\d+)?', where the optional number is
+        matches the regex '(zlib|zstd)(=\d+)?', where the optional number is
         the compression level.
         Any record that is not smaller once compressed is stored uncompressed.
         True is the default and its meaning may change over time:
...
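For reference, a hedged configuration sketch (not part of the commit): the section name and the master_nodes/name values are placeholders, only the 'compress' values follow the regex documented above.

%import neo.client
<NEOStorage>
    # placeholder connection settings:
    master_nodes 127.0.0.1:2424
    name demo
    # any of: false, true, zlib, zlib=9, zstd, zstd=3
    compress zstd
</NEOStorage>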
@@ -16,31 +16,42 @@
 import zlib

+no_zstd = "no binding found for Zstd (de)compression"
+try:
+    import zstd
+    zstd.maxCLevel # is there a better way to check we use Nexedi's bindings
+except (AttributeError, ImportError):
+    zstd = None
+    def _no_zstd(*_):
+        raise ImportError(no_zstd)
+
 decompress_list = (
     lambda data: data,
     zlib.decompress,
+    zstd.decompress if zstd else _no_zstd,
 )

 def parseOption(value):
     x = value.split('=', 1)
     try:
-        alg = ('zlib',).index(x[0])
+        alg = ('zlib', 'zstd').index(x[0])
         if len(x) == 1:
             return alg, None
         level = int(x[1])
     except Exception:
         raise ValueError("not a valid 'compress' option: %r" % value)
-    if 0 < level <= zlib.Z_BEST_COMPRESSION:
+    if (0 != level <= (zstd.maxCLevel if zstd else _no_zstd)() if alg else
+            0 < level <= zlib.Z_BEST_COMPRESSION):
         return alg, level
     raise ValueError("invalid compression level: %r" % level)

 def getCompress(value):
     if value:
         alg, level = (0, None) if value is True else value
-        _compress = zlib.compress
+        _compress = (zstd or _no_zstd() if alg else zlib).compress
         if level:
-            zlib_compress = _compress
-            _compress = lambda data: zlib_compress(data, level)
+            module_compress = _compress
+            _compress = lambda data: module_compress(data, level)
         alg += 1
         assert 0 < alg < len(decompress_list), 'invalid compression algorithm'
         def compress(data):
...
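A small usage sketch (not part of the commit) of the helpers patched above, in neo.lib.compress as imported by the tests further down. The body of the compress() closure is truncated here, so the sketch only exercises option parsing and the decompression table; the fallback used when no zstd binding is installed is purely illustrative.

# Sketch only: option parsing and the algorithm table after this change.
import zlib
from neo.lib import compress

assert compress.parseOption('zlib') == (0, None)
assert compress.parseOption('zlib=9') == (0, 9)
try:
    alg, level = compress.parseOption('zstd=3')   # (1, 3) with Nexedi's zstd bindings
except ImportError:                               # raised when no binding is found
    alg, level = 0, None                          # arbitrary fallback for this sketch
compress.getCompress((alg, level))                # builds the compress() closure truncated above

# decompress_list is indexed by the per-record algorithm code:
# 0 = uncompressed, 1 = zlib, 2 = zstd.
payload = zlib.compress(b'x' * 20)
assert compress.decompress_list[1](payload) == b'x' * 20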
@@ -22,7 +22,7 @@ from struct import Struct
 # The protocol version must be increased whenever upgrading a node may require
 # to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and
 # the high order byte 0 is different from TLS Handshake (0x16).
-PROTOCOL_VERSION = 3
+PROTOCOL_VERSION = 4
 ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)

 # Avoid memory errors on corrupted data.
@@ -526,6 +526,12 @@ class PBoolean(PStructItem):
     """
     _fmt = '!?'

+class PByte(PStructItem):
+    """
+    A 8-bits integer number
+    """
+    _fmt = '!B'
+
 class PNumber(PStructItem):
     """
     A integer number (4-bytes length)
@@ -1026,7 +1032,7 @@ class RebaseObject(Packet):
         PTID('serial'),
         PTID('conflict_serial'),
         POption('data',
-            PBoolean('compression'),
+            PByte('compression'),
             PChecksum('checksum'),
             PString('data'),
         ),
@@ -1044,7 +1050,7 @@ class StoreObject(Packet):
     _fmt = PStruct('ask_store_object',
         POID('oid'),
         PTID('serial'),
-        PBoolean('compression'),
+        PByte('compression'),
         PChecksum('checksum'),
         PString('data'),
         PTID('data_serial'),
@@ -1109,7 +1115,7 @@ class GetObject(Packet):
         POID('oid'),
         PTID('serial_start'),
         PTID('serial_end'),
-        PBoolean('compression'),
+        PByte('compression'),
         PChecksum('checksum'),
         PString('data'),
         PTID('data_serial'),
@@ -1579,7 +1585,7 @@ class AddObject(Packet):
     _fmt = PStruct('add_object',
         POID('oid'),
         PTID('serial'),
-        PBoolean('compression'),
+        PByte('compression'),
         PChecksum('checksum'),
         PString('data'),
         PTID('data_serial'),
...
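With the wire field widened from PBoolean to PByte, 'compression' carries the algorithm code rather than a flag: 0 for uncompressed, 1 for zlib, 2 for zstd, i.e. an index into decompress_list (hence the '2 < compression' check in the storage handler further down). A hypothetical receiver-side helper, for illustration only:

# Sketch only: decoding a record from the object packets patched above.
from neo.lib.compress import decompress_list

def decodeRecord(compression, data):   # hypothetical helper, not in the codebase
    # 0 -> identity, 1 -> zlib.decompress, 2 -> zstd.decompress; index 2 is a
    # placeholder that raises ImportError when no zstd binding is installed.
    return decompress_list[compression](data)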
@@ -439,7 +439,7 @@ class DatabaseManager(object):
             6-tuple: Record content.
                 - record serial (int)
                 - serial or next record modifying object (int, None)
-                - compression (boolean-ish, None)
+                - compression (tiny integer, None)
                 - checksum (binary string, None)
                 - data (binary string, None)
                 - data_serial (int, None)
@@ -462,7 +462,7 @@ class DatabaseManager(object):
             6-tuple: Record content.
                 - record serial (packed)
                 - serial or next record modifying object (packed, None)
-                - compression (boolean-ish, None)
+                - compression (tiny integer, None)
                 - checksum (binary string, None)
                 - data (binary string, None)
                 - data_serial (packed, None)
@@ -706,7 +706,7 @@ class DatabaseManager(object):
         If 'checksum_or_id' is a checksum, it must be the result of
         makeChecksum(data) and extra parameters must be (data, compression)
-        where 'compression' indicates if 'data' is compressed.
+        where 'compression' indicates how 'data' is compressed.
         A volatile reference is set to this data until 'releaseData' is called
         with this checksum.
         If called with only an id, it only increment the volatile
...
@@ -473,7 +473,9 @@ class MySQLDatabaseManager(DatabaseManager):
             serial, compression, checksum, data, value_serial = r[0]
         except IndexError:
             return None
-        if compression and compression & 0x80:
+        if compression is None:
+            compression = 0
+        elif compression & 0x80:
             compression &= 0x7f
             data = ''.join(self._bigData(data))
         return (serial, self._getNextTID(partition, oid, serial),
...
@@ -331,7 +331,7 @@ class SQLiteDatabaseManager(DatabaseManager):
             checksum = str(checksum)
             data = str(data)
         return (serial, self._getNextTID(partition, oid, serial),
-                compression, checksum, data, value_serial)
+                compression or 0, checksum, data, value_serial)

     def _changePartitionTable(self, cell_list, reset=False):
         q = self.query
...
@@ -95,7 +95,7 @@ class ClientOperationHandler(BaseHandler):
     def askStoreObject(self, conn, oid, serial,
             compression, checksum, data, data_serial, ttid):
-        if 1 < compression:
+        if 2 < compression:
             raise ProtocolError('invalid compression value')
         # register the transaction
         self.app.tm.register(conn, ttid)
...
@@ -32,7 +32,7 @@ from neo.lib.connection import ConnectionClosed, \
     ServerConnection, MTClientConnection
 from neo.lib.exception import StoppedOperation
 from neo.lib.handler import DelayEvent, EventHandler
-from neo.lib import logging
+from neo.lib import compress, logging
 from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
     Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
 from .. import expectedFailure, unpickle_state, Patch, TransactionalResource
@@ -60,8 +60,8 @@ class PCounterWithResolution(PCounter):
 class Test(NEOThreadedTest):

-    def testBasicStore(self, dedup=False):
-        with NEOCluster(dedup=dedup) as cluster:
+    def testBasicStore(self, compress_alg=0, dedup=False):
+        with NEOCluster(compress=(compress_alg, None), dedup=dedup) as cluster:
             cluster.start()
             storage = cluster.getZODBStorage()
             storage.sync()
@@ -71,7 +71,7 @@ class Test(NEOThreadedTest):
             compressible = 'x' * 20
             compressed = compress(compressible)
             oid_list = []
-            if cluster.storage.getAdapter() == 'SQLite':
+            if cluster.storage.getAdapter() == 'SQLite' or compress_alg:
                 big = None
                 data = 'foo', '', 'foo', compressed, compressible
             else:
@@ -83,7 +83,7 @@ class Test(NEOThreadedTest):
             self.assertFalse(cluster.storage.sqlCount('data'))
             for data in data:
                 if data is compressible:
-                    key = makeChecksum(compressed), 1
+                    key = makeChecksum(compressed), 1 + compress_alg
                 else:
                     key = makeChecksum(data), 0
                 oid = storage.new_oid()
@@ -116,6 +116,10 @@ class Test(NEOThreadedTest):
                 self.assertFalse(cluster.storage.sqlCount('bigdata'))
             self.assertFalse(cluster.storage.sqlCount('data'))

+    @unittest.skipIf(compress.zstd is None, compress.no_zstd)
+    def testBasicStoreZstd(self):
+        self.testBasicStore(1)
+
     @with_cluster()
     def testDeleteObject(self, cluster):
         if 1:
...
@@ -21,6 +21,7 @@ from ZODB.config import databaseFromString
 from .. import Patch
 from . import ClientApplication, NEOThreadedTest, with_cluster
 from neo.client import Storage
+from neo.lib.compress import zstd

 def databaseFromDict(**kw):
     return databaseFromString("%%import neo.client\n"
@@ -52,6 +53,11 @@ class ConfigTests(NEOThreadedTest):
     def testCompress(self, cluster):
         kw = self.dummy_required.copy()
         valid = ['false', 'true', 'zlib', 'zlib=9']
+        if zstd:
+            valid.append('zstd')
+        else:
+            kw['compress'] = 'zstd'
+            self.assertRaises(ImportError, databaseFromDict, **kw)
         for kw['compress'] in '9', 'best', 'zlib=0', 'zlib=100':
             self.assertRaises(ConfigurationSyntaxError, databaseFromDict, **kw)
         for compress in valid:
...
@@ -40,7 +40,7 @@ class SSLTests(SSLMixin, test.Test):
     # With MySQL, this test is expensive.
     # Let's check deduplication of big oids here.
     def testBasicStore(self):
-        super(SSLTests, self).testBasicStore(True)
+        super(SSLTests, self).testBasicStore(dedup=True)

     def testAbortConnection(self, after_handshake=1):
         with self.getLoopbackConnection() as conn:
...
@@ -39,6 +39,7 @@ extras_require = {
     'storage-sqlite': [],
     'storage-mysqldb': ['mysqlclient'],
     'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'],
+    'zstd': ['cython-zstd'],
 }
 extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
     'neoppod[%s]' % ', '.join(extras_require)]
...
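For completeness: the new packaging extra can be installed with e.g. "pip install neoppod[zstd]", which pulls in the cython-zstd binding that neo.lib.compress probes for at import time.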