Commit d687c276 authored by Julien Muchembled's avatar Julien Muchembled

Add support for Zstd algorithm

parent 9f0f2afe
......@@ -17,7 +17,7 @@
<key name="compress" datatype=".compress">
<description>
The value is either of 'boolean' type or an explicit algorithm that
matches the regex 'zlib(=\d+)?', where the optional number is
matches the regex '(zlib|zstd)(=\d+)?', where the optional number is
the compression level.
Any record that is not smaller once compressed is stored uncompressed.
True is the default and its meaning may change over time:
......
......@@ -16,31 +16,42 @@
import zlib
no_zstd = "no binding found for Zstd (de)compression"
try:
import zstd
zstd.maxCLevel # is there a better way to check we use Nexedi's bindings
except (AttributeError, ImportError):
zstd = None
def _no_zstd(*_):
raise ImportError(no_zstd)
decompress_list = (
lambda data: data,
zlib.decompress,
zstd.decompress if zstd else _no_zstd,
)
def parseOption(value):
x = value.split('=', 1)
try:
alg = ('zlib',).index(x[0])
alg = ('zlib', 'zstd').index(x[0])
if len(x) == 1:
return alg, None
level = int(x[1])
except Exception:
raise ValueError("not a valid 'compress' option: %r" % value)
if 0 < level <= zlib.Z_BEST_COMPRESSION:
if (0 != level <= (zstd.maxCLevel if zstd else _no_zstd)() if alg else
0 < level <= zlib.Z_BEST_COMPRESSION):
return alg, level
raise ValueError("invalid compression level: %r" % level)
def getCompress(value):
if value:
alg, level = (0, None) if value is True else value
_compress = zlib.compress
_compress = (zstd or _no_zstd() if alg else zlib).compress
if level:
zlib_compress = _compress
_compress = lambda data: zlib_compress(data, level)
module_compress = _compress
_compress = lambda data: module_compress(data, level)
alg += 1
assert 0 < alg < len(decompress_list), 'invalid compression algorithm'
def compress(data):
......
......@@ -22,7 +22,7 @@ from struct import Struct
# The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and
# the high order byte 0 is different from TLS Handshake (0x16).
PROTOCOL_VERSION = 3
PROTOCOL_VERSION = 4
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)
# Avoid memory errors on corrupted data.
......@@ -526,6 +526,12 @@ class PBoolean(PStructItem):
"""
_fmt = '!?'
class PByte(PStructItem):
"""
A 8-bits integer number
"""
_fmt = '!B'
class PNumber(PStructItem):
"""
A integer number (4-bytes length)
......@@ -1026,7 +1032,7 @@ class RebaseObject(Packet):
PTID('serial'),
PTID('conflict_serial'),
POption('data',
PBoolean('compression'),
PByte('compression'),
PChecksum('checksum'),
PString('data'),
),
......@@ -1044,7 +1050,7 @@ class StoreObject(Packet):
_fmt = PStruct('ask_store_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PByte('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
......@@ -1109,7 +1115,7 @@ class GetObject(Packet):
POID('oid'),
PTID('serial_start'),
PTID('serial_end'),
PBoolean('compression'),
PByte('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
......@@ -1579,7 +1585,7 @@ class AddObject(Packet):
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PByte('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
......
......@@ -439,7 +439,7 @@ class DatabaseManager(object):
6-tuple: Record content.
- record serial (int)
- serial or next record modifying object (int, None)
- compression (boolean-ish, None)
- compression (tiny integer, None)
- checksum (binary string, None)
- data (binary string, None)
- data_serial (int, None)
......@@ -462,7 +462,7 @@ class DatabaseManager(object):
6-tuple: Record content.
- record serial (packed)
- serial or next record modifying object (packed, None)
- compression (boolean-ish, None)
- compression (tiny integer, None)
- checksum (binary string, None)
- data (binary string, None)
- data_serial (packed, None)
......@@ -706,7 +706,7 @@ class DatabaseManager(object):
If 'checksum_or_id' is a checksum, it must be the result of
makeChecksum(data) and extra parameters must be (data, compression)
where 'compression' indicates if 'data' is compressed.
where 'compression' indicates how 'data' is compressed.
A volatile reference is set to this data until 'releaseData' is called
with this checksum.
If called with only an id, it only increment the volatile
......
......@@ -473,7 +473,9 @@ class MySQLDatabaseManager(DatabaseManager):
serial, compression, checksum, data, value_serial = r[0]
except IndexError:
return None
if compression and compression & 0x80:
if compression is None:
compression = 0
elif compression & 0x80:
compression &= 0x7f
data = ''.join(self._bigData(data))
return (serial, self._getNextTID(partition, oid, serial),
......
......@@ -331,7 +331,7 @@ class SQLiteDatabaseManager(DatabaseManager):
checksum = str(checksum)
data = str(data)
return (serial, self._getNextTID(partition, oid, serial),
compression, checksum, data, value_serial)
compression or 0, checksum, data, value_serial)
def _changePartitionTable(self, cell_list, reset=False):
q = self.query
......
......@@ -95,7 +95,7 @@ class ClientOperationHandler(BaseHandler):
def askStoreObject(self, conn, oid, serial,
compression, checksum, data, data_serial, ttid):
if 1 < compression:
if 2 < compression:
raise ProtocolError('invalid compression value')
# register the transaction
self.app.tm.register(conn, ttid)
......
......@@ -32,7 +32,7 @@ from neo.lib.connection import ConnectionClosed, \
ServerConnection, MTClientConnection
from neo.lib.exception import StoppedOperation
from neo.lib.handler import DelayEvent, EventHandler
from neo.lib import logging
from neo.lib import compress, logging
from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
from .. import expectedFailure, unpickle_state, Patch, TransactionalResource
......@@ -60,8 +60,8 @@ class PCounterWithResolution(PCounter):
class Test(NEOThreadedTest):
def testBasicStore(self, dedup=False):
with NEOCluster(dedup=dedup) as cluster:
def testBasicStore(self, compress_alg=0, dedup=False):
with NEOCluster(compress=(compress_alg,None)) as cluster:
cluster.start()
storage = cluster.getZODBStorage()
storage.sync()
......@@ -71,7 +71,7 @@ class Test(NEOThreadedTest):
compressible = 'x' * 20
compressed = compress(compressible)
oid_list = []
if cluster.storage.getAdapter() == 'SQLite':
if cluster.storage.getAdapter() == 'SQLite' or compress_alg:
big = None
data = 'foo', '', 'foo', compressed, compressible
else:
......@@ -83,7 +83,7 @@ class Test(NEOThreadedTest):
self.assertFalse(cluster.storage.sqlCount('data'))
for data in data:
if data is compressible:
key = makeChecksum(compressed), 1
key = makeChecksum(compressed), 1 + compress_alg
else:
key = makeChecksum(data), 0
oid = storage.new_oid()
......@@ -116,6 +116,10 @@ class Test(NEOThreadedTest):
self.assertFalse(cluster.storage.sqlCount('bigdata'))
self.assertFalse(cluster.storage.sqlCount('data'))
@unittest.skipIf(compress.zstd is None, compress.no_zstd)
def testBasicStoreZstd(self):
self.testBasicStore(1)
@with_cluster()
def testDeleteObject(self, cluster):
if 1:
......
......@@ -21,6 +21,7 @@ from ZODB.config import databaseFromString
from .. import Patch
from . import ClientApplication, NEOThreadedTest, with_cluster
from neo.client import Storage
from neo.lib.compress import zstd
def databaseFromDict(**kw):
return databaseFromString("%%import neo.client\n"
......@@ -52,6 +53,11 @@ class ConfigTests(NEOThreadedTest):
def testCompress(self, cluster):
kw = self.dummy_required.copy()
valid = ['false', 'true', 'zlib', 'zlib=9']
if zstd:
valid.append('zstd')
else:
kw['compress'] = 'zstd'
self.assertRaises(ImportError, databaseFromDict, **kw)
for kw['compress'] in '9', 'best', 'zlib=0', 'zlib=100':
self.assertRaises(ConfigurationSyntaxError, databaseFromDict, **kw)
for compress in valid:
......
......@@ -40,7 +40,7 @@ class SSLTests(SSLMixin, test.Test):
# With MySQL, this test is expensive.
# Let's check deduplication of big oids here.
def testBasicStore(self):
super(SSLTests, self).testBasicStore(True)
super(SSLTests, self).testBasicStore(dedup=True)
def testAbortConnection(self, after_handshake=1):
with self.getLoopbackConnection() as conn:
......
......@@ -39,6 +39,7 @@ extras_require = {
'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'],
'zstd': ['cython-zstd'],
}
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment