# Understanding tensordot


After learning how to use `einsum`, I am now trying to understand how `np.tensordot` works.

However, I am a little lost, especially regarding the various possibilities for the parameter `axes`.

To understand it, as I have never practiced tensor calculus, I use the following example:

```
A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))
```

In this case, what are the different possible `np.tensordot` computations, and how would you compute them manually?

The idea with `tensordot` is pretty simple: we input the arrays and the respective axes along which the sum-reductions are intended. The axes that take part in the sum-reduction are removed in the output, and all the remaining axes from the input arrays are spread out as separate axes in the output, keeping the order in which the input arrays are fed.
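
As a concrete sanity check (my own illustration, not part of the original answer), the rule can be spelled out with explicit loops for the arrays from the question, contracting A's axis 0 with B's axis 1:

```
import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Contract A's axis 0 (length 2) with B's axis 1 (length 2).
# The reduced axes disappear; A's free axes (3, 5) come first,
# then B's free axes (3, 4), giving an output of shape (3, 5, 3, 4).
out = np.zeros((3, 5, 3, 4))
for j in range(3):
    for k in range(5):
        for l in range(3):
            for m in range(4):
                out[j, k, l, m] = sum(A[i, j, k] * B[l, i, m] for i in range(2))

assert np.array_equal(out, np.tensordot(A, B, axes=((0,), (1,))))
```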

Let’s look at a few sample cases with one and two axes of sum-reduction, and also swap the input places to see how the order is kept in the output.

## I. One axis of sum-reduction

Inputs :

```
In [7]: A = np.random.randint(2, size=(2, 6, 5))
   ...: B = np.random.randint(2, size=(3, 2, 4))
```

Case #1:

```
In [9]: np.tensordot(A, B, axes=((0),(1))).shape
Out[9]: (6, 5, 3, 4)

A : (2, 6, 5) -> reduction of axis=0
B : (3, 2, 4) -> reduction of axis=1

Output : (2, 6, 5), (3, 2, 4) ===(2 gone)==> (6,5) + (3,4) => (6,5,3,4)
```

Case #2 (same as Case #1, but with the inputs swapped):

```
In [8]: np.tensordot(B, A, axes=((1),(0))).shape
Out[8]: (3, 4, 6, 5)

B : (3, 2, 4) -> reduction of axis=1
A : (2, 6, 5) -> reduction of axis=0

Output : (3, 2, 4), (2, 6, 5) ===(2 gone)==> (3,4) + (6,5) => (3,4,6,5)
```
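
Both cases can be cross-checked against `einsum` (a verification sketch of my own; the repeated index `i` is the one summed away):

```
import numpy as np

A = np.random.randint(2, size=(2, 6, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Case #1: A's axis 0 pairs with B's axis 1; free axes keep input order.
assert np.allclose(np.tensordot(A, B, axes=((0,), (1,))),
                   np.einsum('ijk,lim->jklm', A, B))

# Case #2: swapping the inputs puts B's free axes first in the output.
assert np.allclose(np.tensordot(B, A, axes=((1,), (0,))),
                   np.einsum('lim,ijk->lmjk', B, A))
```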

## II. Two axes of sum-reduction

Inputs :

```
In [11]: A = np.random.randint(2, size=(2, 3, 5))
    ...: B = np.random.randint(2, size=(3, 2, 4))
```

Case #1:

```
In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
Out[12]: (5, 4)

A : (2, 3, 5) -> reduction of axis=(0,1)
B : (3, 2, 4) -> reduction of axis=(1,0)

Output : (2, 3, 5), (3, 2, 4) ===(2,3 gone)==> (5) + (4) => (5,4)
```

Case #2:

```
In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
Out[14]: (4, 5)

B : (3, 2, 4) -> reduction of axis=(1,0)
A : (2, 3, 5) -> reduction of axis=(0,1)

Output : (3, 2, 4), (2, 3, 5) ===(2,3 gone)==> (4) + (5) => (4,5)
```
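
To compute this "manually", in the spirit of the question, one can line up the contracted axes by broadcasting, multiply elementwise, and sum the paired axes away (my own sketch, not from the original answer):

```
import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Transpose B to (2, 3, 4) so its first two axes line up with A's (0, 1),
# broadcast a product of shape (2, 3, 5, 4), then sum out the paired axes.
Bt = B.transpose(1, 0, 2)
manual = (A[:, :, :, None] * Bt[:, :, None, :]).sum(axis=(0, 1))

assert np.array_equal(manual, np.tensordot(A, B, axes=((0, 1), (1, 0))))
```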

We can extend this to as many axes as possible.

`tensordot` swaps axes and reshapes the inputs so it can apply `np.dot` to two 2d arrays. It then swaps and reshapes back to the target shape. It may be easier to experiment than to explain. There’s no special tensor math going on, just an extension of `dot` to work in higher dimensions. `tensor` here just means arrays with more than 2 dimensions. If you are already comfortable with `einsum`, the simplest approach is to compare the results against that.
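
Here is a rough reconstruction of that transpose/reshape/`np.dot` pipeline for the two-pair case (a sketch of the mechanics, not `tensordot`'s actual source):

```
import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Move A's summed axes (0, 1) to the end and flatten them; move B's
# summed axes (1, 0) to the front, in the same pairing order, and flatten.
A2 = A.transpose(2, 0, 1).reshape(5, 2 * 3)   # (5, 6)
B2 = B.transpose(1, 0, 2).reshape(2 * 3, 4)   # (6, 4)

# Now it is an ordinary 2d matrix product; here no final reshape is
# needed since only one free axis remains on each side.
out = np.dot(A2, B2)                          # (5, 4)

assert np.array_equal(out, np.tensordot(A, B, axes=((0, 1), (1, 0))))
```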

A sample test, summing on one pair of axes:

```
In [823]: np.tensordot(A,B,[0,1]).shape
Out[823]: (3, 5, 3, 4)
In [824]: np.einsum('ijk,lim',A,B).shape
Out[824]: (3, 5, 3, 4)
In [825]: np.allclose(np.einsum('ijk,lim',A,B),np.tensordot(A,B,[0,1]))
Out[825]: True
```

Another, summing on two pairs:

```
In [826]: np.tensordot(A,B,[(0,1),(1,0)]).shape
Out[826]: (5, 4)
In [827]: np.einsum('ijk,jim',A,B).shape
Out[827]: (5, 4)
In [828]: np.allclose(np.einsum('ijk,jim',A,B),np.tensordot(A,B,[(0,1),(1,0)]))
Out[828]: True
```

We could do the same with the `(1,0)` pair. Given the mix of dimensions, I don’t think there’s another combination.
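
For completeness, that `(1,0)` pair looks like this (same kind of check, written out by me rather than taken from the original session):

```
# A's axis 1 paired with B's axis 0 (both length 3):
np.tensordot(A, B, [1, 0]).shape                                     # (2, 5, 2, 4)
np.allclose(np.einsum('ijk,jlm', A, B), np.tensordot(A, B, [1, 0]))  # True
```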

The answers above are great and helped me a lot in understanding `tensordot`, but they don’t show the actual math behind the operations. That’s why I worked through equivalent operations in TensorFlow 2 for myself and decided to share them here:

```
import tensorflow as tf

a = tf.constant([1, 2.])
b = tf.constant([2, 3.])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('i,j', a, b)\t\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, ((),()))}\t tf.einsum('i,j', a, b)\t\t- ((() axis of a), (() axis of b))")
print(f"{tf.tensordot(b, a, 0)}\t tf.einsum('i,j->ji', a, b)\t- ((the last 0 axes of b), (the first 0 axes of a))")
print(f"{tf.tensordot(a, b, 1)}\t\t tf.einsum('i,i', a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((0,), (0,)))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")

[[2. 3.]
 [4. 6.]]    tf.einsum('i,j', a, b)     - ((the last 0 axes of a), (the first 0 axes of b))
[[2. 3.]
 [4. 6.]]    tf.einsum('i,j', a, b)     - ((() axis of a), (() axis of b))
[[2. 4.]
 [3. 6.]]    tf.einsum('i,j->ji', a, b) - ((the last 0 axes of b), (the first 0 axes of a))
8.0          tf.einsum('i,i', a, b)     - ((the last 1 axes of a), (the first 1 axes of b))
8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))
8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))
```

And for `(2,2)` shape:

```
a = tf.constant([[ 1, 2],
                 [-2, 3.]])

b = tf.constant([[-2, 3],
                 [ 0, 4.]])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('ij,kl', a, b)\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t tf.einsum('ij,ik', a, b)\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,1))}\t tf.einsum('ij,ki', a, b)\t- ((0th axis of a), (1st axis of b))")
print(f"{tf.tensordot(a, b, 1)}\t tf.matmul(a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((1,), (0,)))}\t tf.einsum('ij,jk', a, b)\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (1, 0))}\t tf.matmul(a, b)\t\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, 2)}\t tf.reduce_sum(tf.multiply(a, b))\t- ((the last 2 axes of a), (the first 2 axes of b))")
print(f"{tf.tensordot(a, b, ((0,1), (0,1)))}\t tf.einsum('ij,ij->', a, b)\t\t- ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))")

[[[[-2.  3.]
   [ 0.  4.]]
  [[-4.  6.]
   [ 0.  8.]]]

 [[[ 4. -6.]
   [-0. -8.]]
  [[-6.  9.]
   [ 0. 12.]]]]  tf.einsum('ij,kl', a, b)   - ((the last 0 axes of a), (the first 0 axes of b))
[[-2. -5.]
 [-4. 18.]]      tf.einsum('ij,ik', a, b)   - ((0th axis of a), (0th axis of b))
[[-8. -8.]
 [ 5. 12.]]      tf.einsum('ij,ki', a, b)   - ((0th axis of a), (1st axis of b))
[[-2. 11.]
 [ 4.  6.]]      tf.matmul(a, b)            - ((the last 1 axes of a), (the first 1 axes of b))
[[-2. 11.]
 [ 4.  6.]]      tf.einsum('ij,jk', a, b)   - ((1st axis of a), (0th axis of b))
[[-2. 11.]
 [ 4.  6.]]      tf.matmul(a, b)            - ((1st axis of a), (0th axis of b))
16.0    tf.reduce_sum(tf.multiply(a, b))    - ((the last 2 axes of a), (the first 2 axes of b))
16.0    tf.einsum('ij,ij->', a, b)          - ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))
```
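
Back in NumPy, `np.tensordot` accepts the same scalar `axes` shorthand, so the vector cases above translate directly (a quick sketch of my own):

```
import numpy as np

a = np.array([1., 2.])
b = np.array([2., 3.])

np.tensordot(a, b, axes=0)   # outer product: [[2., 3.], [4., 6.]]
np.tensordot(a, b, axes=1)   # inner product: 8.0
```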
