#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.
import functools
import os
from collections import defaultdict, deque
import numpy as np
import torch
# Public names exported by ``from <this module> import *``.
__all__ = [
    'AverageMeter', 'MeterBuffer',
    'get_total_and_free_memory_in_Mb', 'gpu_mem_usage',
]
def get_total_and_free_memory_in_Mb(cuda_device):
    """Query ``nvidia-smi`` for the memory figures of a single GPU.

    Args:
        cuda_device: index (int, or anything ``int()`` accepts) of the GPU in
            the order ``nvidia-smi`` lists devices.

    Returns:
        tuple[int, int]: ``(total, used)`` memory as reported by nvidia-smi
        (units stripped via ``nounits``; nvidia-smi reports MiB). Despite the
        function name, the second element is the *used* memory, not the free
        memory; free memory is ``total - used``.

    Raises:
        ValueError: if the device's output line cannot be unpacked/parsed
            (e.g. nvidia-smi produced no output).
        IndexError: if ``cuda_device`` is beyond the listed devices.
    """
    query_cmd = (
        'nvidia-smi --query-gpu=memory.total,memory.used '
        '--format=csv,nounits,noheader'
    )
    # Use the pipe as a context manager so the stream is closed and the
    # nvidia-smi child process is waited on, instead of leaking both until GC.
    with os.popen(query_cmd) as pipe:
        devices_info = pipe.read().strip().split('\n')
    # NOTE(review): this indexes nvidia-smi's device list, which is not
    # remapped by CUDA_VISIBLE_DEVICES — confirm callers pass that index.
    total, used = devices_info[int(cuda_device)].split(',')
    return int(total), int(used)
def gpu_mem_usage():
    """Return the GPU memory figure from ``torch.cuda.max_memory_allocated()``
    for the current device, converted from bytes to MB (1 MB = 1024 * 1024
    bytes).
    """
    bytes_per_mb = 1024 ** 2
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / bytes_per_mb
class AverageMeter:
    """Record a stream of values and report statistics over them.

    Two views are maintained: a sliding window holding the most recent
    ``window_size`` values (used by ``median``, ``avg`` and ``latest``), and a
    running sum/count over every value added since construction or the last
    ``reset()`` (used by ``global_avg`` and ``total``).
    """

    def __init__(self, window_size=50):
        # A bounded deque drops the oldest value once the window is full.
        self._deque = deque(maxlen=window_size)
        self._total = 0.0
        self._count = 0

    def update(self, value):
        """Add ``value`` to the window and to the running sum/count."""
        self._deque.append(value)
        self._count += 1
        self._total += value

    @property
    def median(self):
        """Median of the values currently in the window (nan when empty)."""
        window = np.asarray(list(self._deque))
        return np.median(window)

    @property
    def avg(self):
        """Mean of the values currently in the window (nan when empty)."""
        return np.asarray(list(self._deque)).mean()

    @property
    def global_avg(self):
        """Mean of every value added since construction or the last reset."""
        # The tiny floor avoids dividing by zero before any update; the total
        # is 0.0 in that case, so the result is 0.0.
        denominator = max(1e-5, self._count)
        return self._total / denominator

    @property
    def latest(self):
        """Most recent value in the window, or None if the window is empty."""
        if not self._deque:
            return None
        return self._deque[-1]

    @property
    def total(self):
        """Running sum of every value added since construction or reset."""
        return self._total

    def reset(self):
        """Empty the window and zero the running sum and count."""
        self._deque.clear()
        self._total = 0.0
        self._count = 0

    def clear(self):
        """Empty the window only; ``total`` and ``global_avg`` are kept."""
        self._deque.clear()
class MeterBuffer(defaultdict):
    """A dict of named ``AverageMeter`` objects, created on first access.

    Looking up a missing key creates a new ``AverageMeter`` whose window size
    is the ``window_size`` given to this buffer.
    """

    def __init__(self, window_size=20):
        factory = functools.partial(AverageMeter, window_size=window_size)
        super().__init__(factory)

    def reset(self):
        """Call ``reset()`` on every meter in the buffer."""
        for meter in self.values():
            meter.reset()

    def get_filtered_meter(self, filter_key='time'):
        """Return a plain dict of the meters whose name contains ``filter_key``."""
        return {k: v for k, v in self.items() if filter_key in k}

    def update(self, values=None, **kwargs):
        """Feed new values into the meters, keyed by meter name.

        Args:
            values: optional mapping of meter name -> value. It is not
                modified by this call.
            **kwargs: extra name -> value pairs; on a key clash they take
                precedence over the entry in ``values``.

        Tensor values are detached from the autograd graph before being
        passed to the meter.
        """
        # Merge into a fresh dict so the caller's mapping is not modified.
        merged = {} if values is None else dict(values)
        merged.update(kwargs)
        for k, v in merged.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            self[k].update(v)

    def clear_meters(self):
        """Call ``clear()`` on every meter in the buffer."""
        for meter in self.values():
            meter.clear()