"""Source code for catalyst.contrib.tools.tensorboard."""
from typing import BinaryIO, Optional, Union
from collections import namedtuple
from collections.abc import Iterable
from pathlib import Path
import struct
import numpy as np
from tensorboardX import SummaryWriter as tensorboardX_SummaryWriter
from tensorboardX.crc32c import crc32c
from tensorboardX.proto.event_pb2 import Event
# Re-export tensorboardX's SummaryWriter under this module's namespace so
# callers can import it from here without depending on tensorboardX directly.
SummaryWriter = tensorboardX_SummaryWriter
def _u32(x):
return x & 0xFFFFFFFF
def _masked_crc32c(data):
    """Compute the masked CRC32C used by the TFRecord framing format."""
    crc = _u32(crc32c(data))
    # Rotate right by 15 bits, then add the TFRecord mask constant.
    rotated = _u32((crc >> 15) | _u32(crc << 17))
    return _u32(rotated + 0xA282EAD8)
class EventReadingException(Exception):
    """Raised when a Tensorboard events file cannot be read or parsed."""
class EventsFileReader(Iterable):
    """An iterator over a Tensorboard events file."""

    def __init__(self, events_file: BinaryIO):
        """Initialize an iterator over an events file.

        Args:
            events_file: An opened file-like object.
        """
        self._events_file = events_file

    def _read(self, size: int) -> Optional[bytes]:
        """Read exactly the next `size` bytes from the current stream.

        Args:
            size: A size in bytes to be read.

        Returns:
            A `bytes` object with read data or `None` on EOF.

        Raises:
            NotImplementedError: if the stream is in non-blocking mode
                (``read`` returned ``None``).
            EventReadingException: if the stream ended mid-record.
        """
        data = self._events_file.read(size)
        if data is None:
            raise NotImplementedError(
                "Reading of a stream in non-blocking mode"
            )
        if 0 < len(data) < size:
            raise EventReadingException(
                "The size of read data is less than requested size"
            )
        if len(data) == 0:
            return None
        return data

    def _read_and_check(self, size: int) -> Optional[bytes]:
        """Read `size` bytes and verify them against their trailing CRC32C.

        Args:
            size: A size in bytes to be read.

        Returns:
            The read payload or `None` on a clean EOF.

        Raises:
            EventReadingException: if the checksum is missing (file
                truncated between payload and checksum) or does not match.
        """
        data = self._read(size)
        if data is None:
            return None
        checksum_size = struct.calcsize("I")
        checksum_raw = self._read(checksum_size)
        # Fix: previously a truncated file (payload present, checksum
        # missing) fed None into struct.unpack, raising TypeError instead
        # of the module's own error type.
        if checksum_raw is None:
            raise EventReadingException("Unexpected end of events file")
        checksum = struct.unpack("I", checksum_raw)[0]
        checksum_computed = _masked_crc32c(data)
        if checksum != checksum_computed:
            raise EventReadingException(
                "Invalid checksum. {checksum} != {crc32}".format(
                    checksum=checksum, crc32=checksum_computed
                )
            )
        return data

    def __iter__(self) -> Event:
        """Iterate over events in the current events file.

        Yields:
            Parsed `Event` protobuf objects, one per record.

        Raises:
            EventReadingException: if a record is truncated or corrupted.
        """
        while True:
            header_size = struct.calcsize("Q")
            header = self._read_and_check(header_size)
            if header is None:
                # Clean EOF: no more records.
                break
            event_size = struct.unpack("Q", header)[0]
            event_raw = self._read_and_check(event_size)
            if event_raw is None:
                raise EventReadingException("Unexpected end of events file")
            event = Event()
            event.ParseFromString(event_raw)
            yield event
# One decoded summary entry: a single logged value for a tag at a step.
SummaryItem = namedtuple(
    "SummaryItem",
    "tag step wall_time value type",
)
def _get_scalar(value) -> Optional[np.ndarray]:
"""Decode an scalar event.
Args:
value: A value field of an event
Returns:
Decoded scalar
"""
if value.HasField("simple_value"):
return value.simple_value
return None
class SummaryReader(Iterable):
    """Iterates over events in all the files in the current logdir.

    .. note::
        Only scalars are supported at the moment.
    """

    # Maps a type name to the decoder used for values of that type.
    _DECODERS = {  # noqa: WPS115
        "scalar": _get_scalar,
    }

    def __init__(
        self,
        logdir: Union[str, Path],
        tag_filter: Optional[Iterable] = None,
        types: Iterable = ("scalar",),
    ):
        """Initialize a new summary reader.

        Args:
            logdir: A directory with Tensorboard summary data
            tag_filter: A list of tags to leave (`None` for all)
            types: A list of types to get.
                Only the "scalar" type is supported at the moment.

        Raises:
            ValueError: if `types` contains an unsupported type name.
        """
        self._logdir = Path(logdir)
        self._tag_filter = set(tag_filter) if tag_filter is not None else None
        self._types = set(types)
        self._check_type_names()

    def _check_type_names(self):
        """Validate that every requested type has a registered decoder.

        Raises:
            ValueError: if a requested type name has no decoder.
        """
        if self._types is None:
            return
        if not all(
            type_name in self._DECODERS.keys() for type_name in self._types
        ):
            raise ValueError("Invalid type name")

    def _decode_events(self, events: Iterable) -> Optional[SummaryItem]:
        """
        Convert events to `SummaryItem` instances.

        Returns a generator with decoded events
        or `None` if an event can't be decoded.

        Args:
            events: An iterable with events objects.

        Returns:
            Optional[SummaryItem]: decoded event
        """
        for event in events:
            if not event.HasField("summary"):
                # Not a summary event (e.g. a file_version record):
                # mark it as undecodable and move on. Fix: previously
                # control fell through here and relied on protobuf's
                # empty-default `event.summary` to make the loop a no-op.
                yield None
                continue
            step = event.step
            wall_time = event.wall_time
            for value in event.summary.value:
                tag = value.tag
                for value_type in self._types:
                    decoder = self._DECODERS[value_type]
                    data = decoder(value)
                    if data is not None:
                        yield SummaryItem(
                            tag=tag,
                            step=step,
                            wall_time=wall_time,
                            value=data,
                            type=value_type,
                        )
                    else:
                        yield None

    def _check_tag(self, tag: str) -> bool:
        """Check if a tag matches the current tag filter.

        Args:
            tag: A string with tag

        Returns:
            A boolean value.
        """
        return self._tag_filter is None or tag in self._tag_filter

    def __iter__(self) -> SummaryItem:
        """Iterate over events in all the files in the current logdir.

        Returns:
            A generator with `SummaryItem` objects
        """
        # Sort for a deterministic file order across runs/platforms.
        log_files = sorted(f for f in self._logdir.glob("*") if f.is_file())
        for file_path in log_files:
            with open(file_path, "rb") as f:
                reader = EventsFileReader(f)
                yield from (
                    item
                    for item in self._decode_events(reader)
                    if item is not None
                    and self._check_tag(item.tag)
                    and item.type in self._types
                )
# __all__ = [
# "EventReadingException",
# "EventsFileReader",
# "SummaryItem",
# "SummaryReader",
# "SummaryWriter",
# ]