zero.concat

zero.concat(iterable)[source]

Concatenate items (tensors, numpy-arrays, tuples, dicts etc.) along the first dimension.

concat is a more general version of torch.cat(..., dim=0). It works not only with sequences of tensors, but also with sequences of containers (tuples, dicts etc.) of different types of data (tensors, numpy-arrays, primitive types). See the tutorial and the examples below to understand what the function does.

Parameters

iterable (Iterable[T]) – items of the same structure (for example, “an iterable of tensors” OR “an iterable of tuples of tensors where all the tuples are of the same length” OR “an iterable of dicts of tensors and numpy-arrays where all the dicts have the same keys” etc.)

Returns

Concatenated items of the iterable.

Return type

T

Note

The concatenation algorithm is fully determined by the first item of the iterable. If there are items of different structure, then the function is likely to fail or produce incorrect results, hence the requirement of the same structure for all items of the iterable.
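To make the note concrete, here is a minimal sketch (hypothetical, not the library's actual implementation) of a concatenation routine whose behavior is decided by inspecting only the first item, which is exactly why all items must share the same structure:

```python
import numpy as np

def concat_sketch(iterable):
    # Hypothetical sketch: the algorithm is chosen by looking ONLY at the
    # first item. Items of a different structure would be mishandled.
    items = list(iterable)
    first = items[0]
    if isinstance(first, np.ndarray):
        return np.concatenate(items)
    if isinstance(first, tuple):
        # Transpose "list of tuples" into "tuple of columns" and recurse.
        return tuple(concat_sketch(column) for column in zip(*items))
    if isinstance(first, dict):
        return {key: concat_sketch([item[key] for item in items]) for key in first}
    if isinstance(first, list):
        return [x for item in items for x in item]
    return items  # primitives: the list of items is already the result
```

For instance, `concat_sketch([{'a': [0, 1]}, {'a': [2, 3]}])` dispatches on the first dict and returns `{'a': [0, 1, 2, 3]}`; if the second item were a tuple instead of a dict, the sketch would fail when indexing it by key.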

Warning

The function starts by converting the iterable to a list. Make sure you have enough memory for this operation; otherwise, the memory limit may be exceeded. Note that in most cases a manual implementation would involve the same conversion, so just keep this in mind when using the function.
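The memory implication of the warning can be demonstrated with a tiny, hypothetical stand-in for the materialization step: even if you pass a lazy generator, all batches end up in memory at once before any concatenation happens.

```python
def eager_concat(iterable):
    # Hypothetical stand-in for the conversion the warning describes:
    # all batches live in memory simultaneously after this line.
    items = list(iterable)
    return [x for item in items for x in item]

consumed = []

def batches():
    # A lazy generator of batches that records how far it has been consumed.
    for batch in ([0, 1], [2, 3], [4, 5, 6]):
        consumed.append(batch)
        yield batch

result = eager_concat(batches())
assert result == [0, 1, 2, 3, 4, 5, 6]
assert len(consumed) == 3  # the generator was fully drained up front
```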

See also

iter_batches

Tutorial

For usage examples, scroll further.

If you have an iterable that contains/produces batches of some kind (tensors, numpy-arrays, tuples/dictionaries thereof and other not-too-specific content), then use concat to concatenate all the items. A prominent use case is applying models and functions to batches (e.g. from a DataLoader):

whole_result = concat(map(model_or_fn, batches))
# or
whole_result = concat(expression(x) for x in batches)

For example:

dataset = ...  # PyTorch dataset
loader = DataLoader(dataset, batch_size)

def step(batch):
    X, y = batch
    return model(X), y

y_pred, y = concat(map(step, loader))
assert len(y_pred) == len(dataset) and len(y) == len(dataset)

# or
def step(batch):
    X, y = batch
    return {'y_pred': model(X), 'y': y}

result = concat(map(step, loader))  # the call itself is unchanged
assert len(result['y_pred']) == len(dataset) and len(result['y']) == len(dataset)

The function can be used in combination with iter_batches. For example, this is how pairwise dot products can be calculated in a batchwise manner if full matrix multiplication does not fit into memory:

n_objects = 100
n_features = 16
batch_size = 20
data = torch.randn(n_objects, n_features)
result = concat(
    batch.matmul(data.T).to('cpu') for batch in iter_batches(data, batch_size)
)
assert result.shape == (n_objects, n_objects)

Or even like this:

n_objects = 100
n_features = 16
batch_size = 20
data = torch.randn(n_objects, n_features)
result = concat(
    concat(b.matmul(a.T).to('cpu') for b in iter_batches(data, batch_size)).T
    for a in iter_batches(data, batch_size)
)
assert result.shape == (n_objects, n_objects)

Examples

How to read the examples:

  • the mental model for understanding the following examples is “concatenating data for 3 batches of sizes (2, 2, 3)”. Note that batch sizes are allowed to vary, but the structure is always the same

  • in all the examples, data is a list of batches; in fact, it can be any “iterable of batches”, including iterators and generators; a list is used only to simplify the demonstration

1-D example:

result = concat([
    torch.tensor([0, 1]), torch.tensor([2, 3]), torch.tensor([4, 5, 6])
])
assert torch.equal(result, torch.tensor([0, 1, 2, 3, 4, 5, 6]))

2-D example:

result = concat([
    torch.tensor([
        [0, 0],
        [1, 1]
    ]),
    torch.tensor([
        [2, 2],
        [3, 3]
    ]),
    torch.tensor([
        [4, 4],
        [5, 5],
        [6, 6],
    ]),
])
assert torch.equal(
    result,
    torch.tensor([
        [0, 0],
        [1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5],
        [6, 6]
    ])
)

N-D example: the same pattern; regardless of the number of dimensions, concatenation always happens along the first one.
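For intuition, here is a 3-D illustration using numpy (shown with np.concatenate; for arrays, concat(data) produces the same result). Only the first dimension grows; the trailing dimensions must match across batches:

```python
import numpy as np

# Three 3-D batches with first-dimension sizes (2, 2, 3); trailing dims match.
data = [np.zeros((2, 4, 4)), np.ones((2, 4, 4)), np.full((3, 4, 4), 2.0)]
result = np.concatenate(data)  # concatenation along the first dimension
assert result.shape == (7, 4, 4)
assert np.array_equal(result[:2], data[0])  # batch contents preserved in order
```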

The following examples demonstrate support for different kinds of input data; the data is 1-D everywhere just for simplicity (in practice, dimensions can be arbitrary).

array = np.array
tensor = torch.tensor
l = [0, 1, 2, 3, 4, 5, 6]
a = array([0, 1, 2, 3, 4, 5, 6])
t = tensor([0, 1, 2, 3, 4, 5, 6])

data = [[0, 1], [2, 3], [4, 5, 6]]
assert concat(data) == l

data = [array([0, 1]), array([2, 3]), array([4, 5, 6])]
assert np.array_equal(concat(data), a)

data = [tensor([0, 1]), tensor([2, 3]), tensor([4, 5, 6])]
assert torch.equal(concat(data), t)

# If the items are not lists, arrays, or tensors, the data is returned
# as a list. This makes sense, since the list of such items is already
# the result for all the batches.
data = ['three batches, hence three items', 0, 1.0]
assert concat(data) == data

data = [
    ([0, 1], array([0, 1]), tensor([0, 1])),
    ([2, 3], array([2, 3]), tensor([2, 3])),
    ([4, 5, 6], array([4, 5, 6]), tensor([4, 5, 6])),
]
result = concat(data)
assert isinstance(result, tuple) and len(result) == 3
assert (
    result[0] == l
    and np.array_equal(result[1], a)
    and torch.equal(result[2], t)
)

data = [
    {'l': [0, 1], 'a': array([0, 1]), 't': tensor([0, 1])},
    {'l': [2, 3], 'a': array([2, 3]), 't': tensor([2, 3])},
    {'l': [4, 5, 6], 'a': array([4, 5, 6]), 't': tensor([4, 5, 6])},
]
result = concat(data)
assert isinstance(result, dict) and list(result) == ['l', 'a', 't']
assert (
    result['l'] == l
    and np.array_equal(result['a'], a)
    and torch.equal(result['t'], t)
)