# Source code for deephyper.keras.callbacks.csv_extended_logger

import collections
import csv
import io
import time

import numpy as np
import six
import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.compat import collections_abc

class CSVExtendedLogger(tf.keras.callbacks.Callback):
    """Callback that streams epoch results to a csv file.

    Besides the values found in ``logs``, every row records the epoch index,
    the wall-clock ``timestamp`` (``time.time()``) taken at the end of the
    epoch, and the ``duration`` of that epoch in seconds (measured from
    ``on_epoch_begin``).

    Supports all values that can be represented as a string, including 1D
    iterables such as np.ndarray.

    Example:

    .. code-block:: python

        csv_logger = CSVExtendedLogger('training.log')
        model.fit(X_train, Y_train, callbacks=[csv_logger])

    Args:
        filename: filename of the csv file, e.g. 'run/log.csv'.
        separator: string used to separate elements in the csv file.
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
    """

    def __init__(self, filename, separator=",", append=False):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Opened in on_train_begin, closed in on_train_end.
        self.csv_file = None
        if six.PY2:
            self.file_flags = "b"
            self._open_args = {}
        else:
            self.file_flags = ""
            self._open_args = {"newline": "\n"}
        # Start time of the current epoch, set by on_epoch_begin.
        self.timestamp = None
        super(CSVExtendedLogger, self).__init__()

    @staticmethod
    def _handle_value(k):
        """Convert a log value into something csv.DictWriter can write.

        Strings and 0-d numpy arrays are returned unchanged; any other
        iterable is rendered as a quoted ``"[a, b, ...]"`` string.
        """
        is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
        if isinstance(k, six.string_types):
            return k
        if isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
            return '"[%s]"' % (", ".join(map(str, k)))
        return k

    def on_train_begin(self, logs=None):
        """Open the csv file, in append mode when requested."""
        if self.append:
            if file_io.file_exists(self.filename):
                with open(self.filename, "r" + self.file_flags) as f:
                    # Only write a header when the existing file is empty.
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = io.open(
            self.filename, mode + self.file_flags, **self._open_args
        )

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the epoch start time so its duration can be computed."""
        self.timestamp = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Append one row (epoch, timestamp, duration, metrics...) to the csv."""
        timestamp = time.time()
        duration = timestamp - self.timestamp  # duration of current epoch
        logs = logs or {}

        if self.keys is None:
            # Column set is fixed by the first epoch's logs.
            self.keys = sorted(logs.keys())

        if self.model.stop_training:
            # We set NA so that csv parsers do not fail for this last epoch.
            logs = {k: logs.get(k, "NA") for k in self.keys}

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch", "timestamp", "duration"] + self.keys
            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict(
            [("epoch", epoch), ("timestamp", timestamp), ("duration", duration)]
        )
        row_dict.update(
            (key, self._handle_value(logs[key])) for key in self.keys
        )
        self.writer.writerow(row_dict)
        # Flush every epoch so the log is readable while training runs.
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Close the csv file and reset the writer for a possible next run."""
        if self.csv_file is not None:
            self.csv_file.close()
            self.csv_file = None
        self.writer = None