I am still learning tensorflow and keras, and I suspect this question has a very easy answer I’m just missing due to lack of familiarity.

I have a PrefetchDataset object:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

…made up of features and a target. I can iterate over it using a for loop:

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   0 1 1 0]

However, this is very slow. What I’d like to do is access the tensor corresponding to the class labels and turn that into a numpy array, or a list, or any sort of iterable that can be fed into scikit-learn’s classification report and/or confusion matrix:

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
> y_pred_list = [int(x[0] >= 0.5) for x in y_pred]      # value >= 0.5 is a positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list))

…OR access the data such that it could be used in tensorflow’s confusion matrix:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions))

In both cases, the general ability to grab the target data from the original object in a manner that isn’t computationally expensive would be very helpful (and might help with my underlying intuitions re: tensorflow and keras).

Any advice would be greatly appreciated.

You can convert it to a list with list(ds) and then recompile it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again but at least it’s a nightmare that other people have had before.

Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.

This is far from a satisfying answer but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.
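
As a rough illustration of the list(ds) approach for a batched (features, labels) dataset like the one in the question, here is a minimal sketch. The names `features`, `labels`, `x_all`, `y_all` and `rebuilt`, and the random data, are made up for the example; only the shapes mirror the question.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the PrefetchDataset in the question (random data, assumed shapes).
features = np.random.rand(20, 99).astype("float32")
labels = np.random.randint(0, 2, size=20).astype("int64")
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(6).prefetch(1)

# list(ds) iterates the dataset once and gives a list of (features, labels) batch tuples.
batches = list(ds)

# Concatenate the batches so each array has one row per example.
x_all = np.concatenate([x.numpy() for x, _ in batches], axis=0)
y_all = np.concatenate([y.numpy() for _, y in batches], axis=0)

# Rebuild an ordinary, unbatched Dataset from the two arrays.
rebuilt = tf.data.Dataset.from_tensor_slices((x_all, y_all))
```

Everything ends up in memory at once, so this is only reasonable when the dataset is small enough to fit.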

You can use map to select either the input or the label from every (input, label) pair, and turn the result into a list:

import tensorflow as tf
import numpy as np

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

ds = tf.data.Dataset.from_tensor_slices((inputs, targets))

X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))

If you want to retain the batches or extract all the labels as a single tensor you could use the following function:

def get_labels_from_tfdataset(tfdataset, batched=False):

    labels = list(map(lambda x: x[1], tfdataset)) # Get labels 

    if not batched:
        return tf.concat(labels, axis=0) # concat the list of batched labels

    return labels
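
To tie this back to the confusion matrix in the question, here is a small sketch under stated assumptions: `tf_test` is a stand-in dataset built from made-up labels, and `y_prob` is a hand-written array playing the role of the sigmoid outputs that model.predict would return. Neither comes from the original post.

```python
import numpy as np
import tensorflow as tf

# Stand-in batched dataset and made-up sigmoid outputs (both are assumptions).
features = np.random.rand(10, 99).astype("float32")
labels = np.array([0, 0, 1, 0, 1, 0, 0, 0, 1, 0], dtype="int64")
tf_test = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4).prefetch(1)
y_prob = np.array([[0.1], [0.7], [0.9], [0.2], [0.4], [0.3], [0.1], [0.6], [0.8], [0.2]])

# Collect the label batches in one pass and join them into a single 1-D array.
y_true = np.concatenate([y.numpy() for _, y in tf_test], axis=0)

# Threshold the probabilities at 0.5 to get class predictions.
y_pred = (y_prob[:, 0] >= 0.5).astype("int64")

# Rows are true classes, columns are predicted classes.
cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=2).numpy()
print(cm)
```

The same `y_true` and `y_pred` arrays can be passed to scikit-learn's confusion_matrix or classification_report. One caveat: if the dataset is shuffled on every iteration, labels collected in a separate pass may not line up with the order of the predictions.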

You can build a list by looping over your PrefetchDataset (train_dataset in my example):

train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]

You can then access each example and label separately by index: train_data[0][0] is the first example and train_data[0][1] is its label.


You can also convert the list into a pandas DataFrame with two columns:

import pandas as pd
pd.DataFrame(train_data, columns=['example', 'label'])

Then, if you want to convert your list back into a dataset, you can use from_generator (chain .prefetch() onto the result to get a PrefetchDataset again):

dataset = tf.data.Dataset.from_generator(
    lambda: train_data, (tf.string, tf.int32))  # replace with the dtypes of your own data

You can check that it worked, for example by iterating over the new dataset and printing a few elements.
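
A full round trip might look like the sketch below. It swaps the positional dtype tuple for the `output_signature` argument of from_generator, which needs a reasonably recent TensorFlow 2 release. The stand-in `train_dataset`, its random data, and the shapes and dtypes in the signature are assumptions chosen to match the question, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Stand-in for train_dataset (random data; shapes and dtypes are assumptions).
features = np.random.rand(10, 99).astype("float32")
labels = np.random.randint(0, 2, size=10).astype("int64")
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4).prefetch(1)

# Materialise the dataset as a list of (example_batch, label_batch) NumPy pairs.
train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]

# Rebuild a dataset from the list, describing each element with a TensorSpec.
dataset = tf.data.Dataset.from_generator(
    lambda: train_data,
    output_signature=(
        tf.TensorSpec(shape=(None, 99), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int64),
    ),
).prefetch(tf.data.AUTOTUNE)

# Check the round trip: same structure, same values.
print(dataset.element_spec)
for (x_old, y_old), (x_new, y_new) in zip(train_data, dataset):
    assert np.array_equal(x_old, x_new.numpy())
    assert np.array_equal(y_old, y_new.numpy())
```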


PrefetchDataset is the class returned by the Dataset.prefetch() method, and it is a subclass of Dataset.

If you set skip_prefetch=True by passing a ReadConfig to the builder, the returned type will be _OptionsDataset instead.

read_config = tfds.ReadConfig(skip_prefetch=True)
# Pass it as a keyword argument when loading, e.g. tfds.load(..., read_config=read_config)


You can use the map() function to select a feature from every element at once:
ratings.map(lambda x: x["feature name"])
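
For a dataset whose elements are dictionaries, that could look like the following sketch. The dataset contents and the feature names "user_id" and "user_rating" are hypothetical, invented only to have something to map over.

```python
import tensorflow as tf

# Hypothetical dict-structured dataset; the feature names are made up.
ratings = tf.data.Dataset.from_tensor_slices({
    "user_id": ["a", "b", "c"],
    "user_rating": [4.0, 2.5, 5.0],
})

# map() is lazy: it returns a new dataset that holds only the selected feature.
user_ratings = ratings.map(lambda x: x["user_rating"])

# Pull the values out as plain Python floats.
values = [float(v) for v in user_ratings.as_numpy_iterator()]
print(values)
```

Note that map() on its own does not read any data; the values are only produced once the mapped dataset is iterated.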