import collections
import io
import time
import csv
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.compat import collections_abc
class CSVExtendedLogger(tf.keras.callbacks.Callback):
    """Callback that streams epoch results to a CSV file, adding per-epoch
    wall-clock timestamps and durations to the logged metrics.

    Supports all values that can be represented as a string, including 1D
    iterables such as ``np.ndarray``.

    Example:

    .. code-block:: python

        csv_logger = CSVExtendedLogger('training.log')
        model.fit(X_train, Y_train, callbacks=[csv_logger])

    Args:
        filename: Filename of the CSV file, e.g. ``'run/log.csv'``.
        separator: String used to separate elements in the CSV file.
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
    """

    def __init__(self, filename, separator=",", append=False):
        super(CSVExtendedLogger, self).__init__()
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        if six.PY2:
            # Python 2's csv module wants binary-mode files and does not
            # accept a ``newline`` argument.
            self.file_flags = "b"
            self._open_args = {}
        else:
            self.file_flags = ""
            self._open_args = {"newline": "\n"}
        # Set in on_epoch_begin; used to compute each epoch's duration.
        self.timestamp = None

    def on_train_begin(self, logs=None):
        """Open the CSV file and decide whether a header row is needed."""
        if self.append:
            if file_io.file_exists(self.filename):
                with open(self.filename, "r" + self.file_flags) as f:
                    # Only write a header when the existing file is empty.
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = io.open(
            self.filename, mode + self.file_flags, **self._open_args
        )

    def on_epoch_begin(self, epoch, logs=None):
        """Record the epoch start time for duration bookkeeping."""
        self.timestamp = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Write one CSV row: epoch, timestamp, duration, then metric keys.

        Metric keys are sorted on the first epoch and reused for all later
        rows so the column order stays stable for the whole run.
        """
        timestamp = time.time()
        duration = timestamp - self.timestamp  # duration of current epoch
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, six.string_types):
                return k
            elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
                # Render 1D iterables as a single quoted "[a, b, ...]" cell.
                return '"[%s]"' % (", ".join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())

        if self.model.stop_training:
            # Training was interrupted, so some metrics may be missing for
            # this final epoch. Substitute the string "NA" (not a tuple, as
            # the previous `(k, "NA")` expression produced) so csv parsers
            # do not fail on this last row.
            logs = {k: logs[k] if k in logs else "NA" for k in self.keys}

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch", "timestamp", "duration"] + self.keys
            if six.PY2:
                # Python 2 csv expects unicode field names; six.text_type is
                # valid syntax on both interpreters, unlike an f-string.
                fieldnames = [six.text_type(x) for x in fieldnames]
            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict(
            {"epoch": epoch, "timestamp": timestamp, "duration": duration}
        )
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Close the CSV file and drop the writer."""
        self.csv_file.close()
        self.writer = None