kwarray.util_torch module
Torch-specific extensions.
- kwarray.util_torch.one_hot_embedding(labels, num_classes, dim=1)
Encode integer class labels in one-hot form.
- Parameters:
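labels (LongTensor | ndarray) – integer class labels, sized [N,]; higher-dimensional label arrays are also accepted (see the examples below).
num_classes (int) – number of classes.
dim (int) – dimension at which the new class axis is created (negative values are also handled).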
- Returns:
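encoded labels; for 1-D labels with dim=1 the result is sized [N, num_classes], and in general the class axis is inserted at position dim.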
- Return type:
Tensor (an ndarray when labels is an ndarray)
References
https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4
Example
>>> # each element in labels must satisfy 0 <= value < num_classes
>>> # xdoctest: +REQUIRES(module:torch)
>>> import numpy as np
>>> import torch
>>> labels = torch.LongTensor([0, 0, 1, 4, 2, 3])
>>> num_classes = int(labels.max()) + 1
>>> t = one_hot_embedding(labels, num_classes)
>>> assert all(row[y] == 1 for row, y in zip(t.numpy(), labels.numpy()))
>>> import ubelt as ub
>>> print(ub.urepr(t.numpy().tolist()))
[
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0, 0.0],
]
>>> t2 = one_hot_embedding(labels.numpy(), num_classes)
>>> assert np.all(t2 == t.numpy())
>>> from kwarray.util_torch import _torch_available_devices
>>> devices = _torch_available_devices()
>>> if devices:
>>>     device = devices[0]
>>>     try:
>>>         t3 = one_hot_embedding(labels.to(device), num_classes)
>>>     except RuntimeError:
>>>         pass
>>>     else:
>>>         # only compare when the device computation succeeded
>>>         assert np.all(t3.cpu().numpy() == t.numpy())
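For a flat label vector with dim=1, one-hot encoding amounts to selecting rows of an identity matrix: row k of the C x C identity is the one-hot vector for class k. The NumPy sketch below illustrates that idea; `one_hot_1d` is a hypothetical helper, not part of kwarray, and it is not claimed to be how kwarray implements `one_hot_embedding`.

```python
import numpy as np


def one_hot_1d(labels, num_classes):
    """Hypothetical sketch: one-hot encode a 1-D integer array (not kwarray's code).

    Fancy-indexing the identity matrix with the labels picks one row per
    label, giving an array of shape (N, num_classes).
    """
    labels = np.asarray(labels)
    return np.eye(int(num_classes))[labels]


labels = np.array([0, 0, 1, 4, 2, 3])
encoded = one_hot_1d(labels, labels.max() + 1)
print(encoded.shape)  # (6, 5)
```

Because `np.eye` produces floats by default, the result holds 1.0/0.0 values, matching the float output printed in the example above.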
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> nC = num_classes = 3
>>> labels = (torch.rand(10, 11, 12) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=2).shape == (10, 11, 3, 12)
>>> assert one_hot_embedding(labels, nC, dim=3).shape == (10, 11, 12, 3)
>>> labels = (torch.rand(10, 11) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11)
>>> labels = (torch.rand(10) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3)
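The shape checks above show the class axis landing at position dim for multi-dimensional labels. One way to get that layout, sketched here in NumPy for non-negative dim only, is to append the class axis and then move it with `np.moveaxis`. `one_hot_at_dim` is a hypothetical helper; kwarray's own handling (including negative dim) may differ.

```python
import numpy as np


def one_hot_at_dim(labels, num_classes, dim=1):
    """Hypothetical sketch: one-hot encode labels and place the class axis at dim.

    np.eye(C)[labels] appends the class axis as the last dimension
    (shape labels.shape + (C,)); np.moveaxis then moves it to position dim.
    """
    labels = np.asarray(labels)
    encoded = np.eye(int(num_classes))[labels]
    return np.moveaxis(encoded, -1, dim)


rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(10, 11, 12))
print(one_hot_at_dim(labels, 3, dim=0).shape)  # (3, 10, 11, 12)
print(one_hot_at_dim(labels, 3, dim=2).shape)  # (10, 11, 3, 12)
```

Taking `argmax` along dim recovers the original labels, which is a quick sanity check for any such encoding.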
- kwarray.util_torch.one_hot_lookup(data, indices)
Return the value of a particular column for each row in data.
Each item in indices corresponds to a row in data and gives the column index to read from that row.
- Parameters:
data (ArrayLike) – N x C array of values (e.g. class probabilities).
indices (ArrayLike) – length-N integer array with values in [0, C); each entry is the column index for the corresponding row of data.
- Returns:
the selected value for each row, as a length-N array (e.g. the probability of the indexed class)
- Return type:
ArrayLike
Note
This is functionally equivalent to
[row[c] for row, c in zip(data, indices)]
except that it works with pure matrix operations.
Todo
- [ ] Allow the user to specify which dimension indices should be
zipped over. By default it should be dim=0
- [ ] Allow the user to specify which dimension indices should select
from. By default it should be dim=1.
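The Note for `one_hot_lookup` describes it as a matrix-operation replacement for a per-row list comprehension. One way to achieve that, sketched in NumPy under the assumption that data is a 2-D N x C array, is to build a boolean one-hot mask and index data with it. `lookup_via_mask` is a hypothetical name; this is not presented as kwarray's implementation.

```python
import numpy as np


def lookup_via_mask(data, indices):
    """Hypothetical sketch of a one-hot lookup using only array operations.

    The boolean mask has exactly one True per row (at the requested column),
    so boolean-indexing data with it yields one value per row, in row order.
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    mask = np.eye(data.shape[1], dtype=bool)[indices]  # N x C, one True per row
    return data[mask]


data = np.arange(12).reshape(4, 3)
indices = np.array([0, 1, 2, 1])
res = lookup_via_mask(data, indices)
alt = np.array([row[c] for row, c in zip(data, indices)])
print(res)  # [ 0  4  8 10]
```

In plain NumPy the same selection can also be written with integer fancy indexing, `data[np.arange(len(indices)), indices]`, which avoids materializing the N x C mask.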
Example
>>> import numpy as np
>>> from kwarray.util_torch import *  # NOQA
>>> data = np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ])
>>> indices = np.array([0, 1, 2, 1])
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = array([ 0,  4,  8, 10])
>>> alt = np.array([row[c] for row, c in zip(data, indices)])
>>> assert np.all(alt == res)
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> import numpy as np
>>> import torch
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = tensor([ 0,  4,  8, 10]...)
>>> alt = torch.LongTensor([row[c] for row, c in zip(data, indices)])
>>> assert torch.all(alt == res)
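The Todo items for `one_hot_lookup` ask for configurable dimensions. For 2-D data, choosing the axis to select from also fixes the axis that is zipped over, and NumPy's `np.take_along_axis` already expresses "pick one entry per position along a chosen axis". The sketch below builds on it; `lookup_along_axis` and its `axis` parameter are hypothetical and not part of kwarray.

```python
import numpy as np


def lookup_along_axis(data, indices, axis=1):
    """Hypothetical generalization: select one entry per position along axis.

    indices has the shape of data with axis removed; it is expanded to keep
    a length-1 axis so np.take_along_axis can broadcast it, and that axis
    is squeezed away afterwards.
    """
    data = np.asarray(data)
    idx = np.expand_dims(np.asarray(indices), axis)
    return np.take_along_axis(data, idx, axis=axis).squeeze(axis)


data = np.arange(12).reshape(4, 3)
# Select a column for each row (the documented behaviour).
print(lookup_along_axis(data, [0, 1, 2, 1], axis=1))  # [ 0  4  8 10]
# Select a row for each column instead.
print(lookup_along_axis(data, [3, 0, 2], axis=0))  # [9 1 8]
```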