Commit 84b29c90 authored by Rusty Russell's avatar Rusty Russell

tal: tal_dup()

Sometimes you want to duplicate and enlarge in one step: particularly with
TAL_TAKE.  It's also typesafe, unlike tal_memdup().
Signed-off-by: default avatarRusty Russell <rusty@rustcorp.com.au>
parent 7ca46909
......@@ -665,6 +665,12 @@ bool tal_resize_(tal_t **ctxp, size_t size)
old_t = debug_tal(to_tal_hdr(*ctxp));
/* Don't hand silly sizes to realloc. */
if (size >> (CHAR_BIT*sizeof(size) - 1)) {
call_error("Reallocation size overflow");
return false;
}
t = resizefn(old_t, size + sizeof(struct tal_hdr));
if (!t) {
call_error("Reallocation failure");
......@@ -699,7 +705,7 @@ bool tal_resize_(tal_t **ctxp, size_t size)
char *tal_strdup(const tal_t *ctx, const char *p)
{
return tal_memdup(ctx, p, strlen(p)+1);
return tal_dup(ctx, char, p, strlen(p)+1, 0);
}
char *tal_strndup(const tal_t *ctx, const char *p, size_t n)
......@@ -708,20 +714,35 @@ char *tal_strndup(const tal_t *ctx, const char *p, size_t n)
if (strlen(p) < n)
n = strlen(p);
ret = tal_memdup(ctx, p, n+1);
ret = tal_dup(ctx, char, p, n, 1);
if (ret)
ret[n] = '\0';
return ret;
}
void *tal_memdup(const tal_t *ctx, const void *p, size_t n)
void *tal_dup_(const tal_t *ctx, const void *p, size_t n, size_t extra,
const char *label)
{
void *ret;
if (ctx == TAL_TAKE)
return (void *)p;
/* Beware overflow! */
if (n + extra < n || n + extra + sizeof(struct tal_hdr) < n) {
call_error("dup size overflow");
if (ctx == TAL_TAKE)
tal_free(p);
return NULL;
}
ret = tal_arr(ctx, char, n);
if (ctx == TAL_TAKE) {
if (unlikely(!p))
return NULL;
if (!tal_resize_((void **)&p, n + extra)) {
tal_free(p);
return NULL;
}
return (void *)p;
}
ret = tal_alloc_(ctx, n + extra, false, label);
if (ret)
memcpy(ret, p, n);
return ret;
......
......@@ -188,13 +188,18 @@ tal_t *tal_next(const tal_t *root, const tal_t *prev);
tal_t *tal_parent(const tal_t *ctx);
/**
* tal_memdup - duplicate memory.
* tal_dup - duplicate an array.
* @ctx: NULL, or tal allocated object to be parent (or TAL_TAKE).
* @p: the memory to copy
* @n: the number of bytes.
*
* @type: the type (should match type of @p!)
* @p: the array to copy
* @n: the number of sizeof(type) entries to copy.
* @extra: the number of extra sizeof(type) entries to allocate.
*/
void *tal_memdup(const tal_t *ctx, const void *p, size_t n);
#define tal_dup(ctx, type, p, n, extra) \
((type *)tal_dup_((ctx), tal_typechk_(p, type *), \
tal_sizeof_(sizeof(type), (n)), \
tal_sizeof_(sizeof(type), (extra)), \
TAL_LABEL(type, "[]")))
/**
* tal_strdup - duplicate a string
......@@ -315,12 +320,23 @@ static inline size_t tal_sizeof_(size_t size, size_t count)
#if HAVE_TYPEOF
#define tal_typeof(ptr) (__typeof__(ptr))
#if HAVE_STATEMENT_EXPR
/* Careful: ptr can be const foo *, ptype is foo *. Also, ptr could
* be an array, eg "hello". */
#define tal_typechk_(ptr, ptype) ({ __typeof__(&*(ptr)) _p = (ptype)(ptr); _p; })
#else
#define tal_typechk_(ptr, ptype) (ptr)
#endif
#else /* !HAVE_TYPEOF */
#define tal_typeof(ptr)
#define tal_typechk_(ptr, ptype) (ptr)
#endif
void *tal_alloc_(const tal_t *ctx, size_t bytes, bool clear, const char *label);
void *tal_dup_(const tal_t *ctx, const void *p, size_t n, size_t extra,
const char *label);
tal_t *tal_steal_(const tal_t *new_parent, const tal_t *t);
bool tal_resize_(tal_t **ctxp, size_t size);
......
......@@ -7,14 +7,14 @@ static int error_count;
static void my_error(const char *msg)
{
error_count++;
ok1(strstr(msg, "overflow"));
}
int main(void)
{
void *p;
int *pi, *origpi;
plan_tests(6);
plan_tests(26);
tal_set_backend(NULL, NULL, NULL, my_error);
......@@ -25,5 +25,60 @@ int main(void)
p = tal_arr(NULL, char, (size_t)-2);
ok1(!p);
ok1(error_count == 2);
/* Now try overflow cases for tal_dup. */
error_count = 0;
pi = origpi = tal_arr(NULL, int, 100);
ok1(pi);
ok1(error_count == 0);
pi = tal_dup(NULL, int, pi, (size_t)-1, 0);
ok1(!pi);
ok1(error_count == 1);
pi = tal_dup(NULL, int, pi, 0, (size_t)-1);
ok1(!pi);
ok1(error_count == 2);
pi = tal_dup(NULL, int, pi, (size_t)-1UL / sizeof(int),
(size_t)-1UL / sizeof(int));
ok1(!pi);
ok1(error_count == 3);
/* This will still overflow when tal_hdr is added. */
pi = tal_dup(NULL, int, pi, (size_t)-1UL / sizeof(int) / 2,
(size_t)-1UL / sizeof(int) / 2);
ok1(!pi);
ok1(error_count == 4);
/* Now, check that with TAL_TAKE we free old one on failure. */
pi = tal_arr(NULL, int, 100);
error_count = 0;
pi = tal_dup(TAL_TAKE, int, pi, (size_t)-1, 0);
ok1(!pi);
ok1(error_count == 1);
ok1(tal_first(NULL) == origpi && !tal_next(NULL, origpi));
pi = tal_arr(NULL, int, 100);
error_count = 0;
pi = tal_dup(TAL_TAKE, int, pi, 0, (size_t)-1);
ok1(!pi);
ok1(error_count == 1);
ok1(tal_first(NULL) == origpi && !tal_next(NULL, origpi));
pi = tal_arr(NULL, int, 100);
error_count = 0;
pi = tal_dup(TAL_TAKE, int, pi, (size_t)-1UL / sizeof(int),
(size_t)-1UL / sizeof(int));
ok1(!pi);
ok1(error_count == 1);
ok1(tal_first(NULL) == origpi && !tal_next(NULL, origpi));
pi = tal_arr(NULL, int, 100);
error_count = 0;
/* This will still overflow when tal_hdr is added. */
pi = tal_dup(TAL_TAKE, int, pi, (size_t)-1UL / sizeof(int) / 2,
(size_t)-1UL / sizeof(int) / 2);
ok1(!pi);
ok1(error_count == 1);
ok1(tal_first(NULL) == origpi && !tal_next(NULL, origpi));
return exit_status();
}
......@@ -6,7 +6,7 @@ int main(void)
{
char *parent, *c;
plan_tests(9);
plan_tests(13);
parent = tal(NULL, char);
ok1(parent);
......@@ -19,10 +19,19 @@ int main(void)
ok1(strcmp(c, "hel") == 0);
ok1(tal_parent(c) == parent);
c = tal_memdup(parent, "hello", 6);
c = tal_typechk_(parent, char *);
c = tal_dup(parent, char, "hello", 6, 0);
ok1(strcmp(c, "hello") == 0);
ok1(strcmp(tal_name(c), "char[]") == 0);
ok1(tal_parent(c) == parent);
/* Now with an extra byte. */
c = tal_dup(parent, char, "hello", 6, 1);
ok1(strcmp(c, "hello") == 0);
ok1(strcmp(tal_name(c), "char[]") == 0);
ok1(tal_parent(c) == parent);
strcat(c, "x");
c = tal_asprintf(parent, "hello %s", "there");
ok1(strcmp(c, "hello there") == 0);
ok1(tal_parent(c) == parent);
......
......@@ -6,7 +6,7 @@ int main(void)
{
char *parent, *c;
plan_tests(13);
plan_tests(15);
parent = tal(NULL, char);
ok1(parent);
......@@ -25,10 +25,15 @@ int main(void)
ok1(strcmp(c, "hel") == 0);
ok1(tal_parent(c) == parent);
c = tal_memdup(TAL_TAKE, c, 1);
c = tal_dup(TAL_TAKE, char, c, 1, 0);
ok1(c[0] == 'h');
ok1(tal_parent(c) == parent);
c = tal_dup(TAL_TAKE, char, c, 1, 2);
ok1(c[0] == 'h');
strcpy(c, "hi");
ok1(tal_parent(c) == parent);
/* No leftover allocations. */
tal_free(c);
ok1(tal_first(parent) == NULL);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment