166 lines
6.2 KiB
Python
166 lines
6.2 KiB
Python
"""
The ``_Splitter`` class.

A convenient tool for creating and working with folds.
"""
|
|
|
|
import random
|
|
|
|
from ._fold_storage import FoldStorage
|
|
from ._fold_storage import _FoldFile
|
|
|
|
|
|
class _Splitter(object):
    """
    Create folds (disjoint sets of group ids) from a dataset and write them to fold files.

    The splitter needs some parameters to create folds and a "line reader" that can read the
    source dataset. Every line belongs to a group, and folds are built from whole groups, so
    all lines of one group always land in the same fold file.
    """

    # Maximal number of lines written to one "rest" fold file before a new one is started.
    _REST_SIZE = 100000

    def __init__(self, line_reader, column_description, seed, min_folds_count):
        """
        :param line_reader: Reader for getting lines from the dataset. Its ``lines_generator()``
            must yield ``(group_id, line)`` pairs (it is iterated once here and once per
            ``_write_folds`` call), and it must provide ``get_separator()``.
        :param column_description: The description of the features (columns) of the dataset.
        :param seed: Seed of the random generator used to shuffle groups.
        :param min_folds_count: Minimal number of folds that one permutation must be able to
            hold; fold sizes larger than ``groups_count // min_folds_count`` are rejected.
        """
        self._line_reader = line_reader
        # line_groups_ids -- group id of each line, in reading order.
        # groups_ids -- set of all group ids in the dataset.
        self._line_groups_ids, self._groups_ids = self._read_groups_ids()
        # Every fold file ever created; kept so it can be removed at the end of work.
        self._folds_storage = set()
        self._column_description = column_description
        self._min_folds_count = min_folds_count
        self._random = random.Random(seed)

    def _read_groups_ids(self):
        """
        Find all groups in the dataset and the group each line belongs to.

        :return: Tuple ``(line_groups_ids, groups_ids)``: a list with the group id of every
            line in reading order, and the set of distinct group ids.
        """
        line_groups_ids = []
        groups_ids = set()

        # The generator yields (group id of the line, line as string).
        for group_id, _ in self._line_reader.lines_generator():
            line_groups_ids.append(group_id)
            groups_ids.add(group_id)
        return line_groups_ids, groups_ids

    def _make_learn_folds(self, fold_size, left_folds):
        """
        Prepare the folds of a single random permutation of groups.

        :param fold_size: Number of groups in every fold; must be positive.
        :param left_folds: Maximal number of folds to produce.
        :return: List of at most ``left_folds`` disjoint sets of ``fold_size`` group ids each.
        :raises AttributeError: If ``fold_size`` is not positive, or is too big for
            ``min_folds_count`` folds to fit into the dataset.
        """
        if fold_size <= 0:
            # Without this guard fold_size == 0 divides by zero below, and a negative size yields
            # no folds at all, which makes create_fold_sets loop forever.
            raise AttributeError('The size of fold must be positive: fold_size: {}'.format(fold_size))

        count_groups = len(self._groups_ids)
        if count_groups // self._min_folds_count < fold_size:
            raise AttributeError('The size of fold is too big: count_groups: {}, fold_size: {}. Const: {}'.format(
                count_groups, fold_size, self._min_folds_count)
            )

        # Sorting makes the starting order independent of set iteration order, so the shuffle
        # result depends only on the random generator state.
        permutation = sorted(self._groups_ids)
        self._random.shuffle(permutation)

        current_count_folds = min(count_groups // fold_size, left_folds)
        return [
            set(permutation[i * fold_size:(i + 1) * fold_size])
            for i in range(current_count_folds)
        ]

    def _write_folds(self, fold_storages, num, offset):
        """
        Store every dataset line into the fold storage that owns its group.

        Lines whose group is not contained in any of ``fold_storages`` are written to "rest"
        fold files instead; a new rest file is started every ``_REST_SIZE`` such lines.

        :param fold_storages: Fold files (created by ``create_fold``) to distribute lines into.
        :param num: Number used to name the first rest fold file; incremented for each next one.
        :param offset: Fold offset, embedded into the rest fold file names.
        :return: List of the closed, non-empty rest fold files.
        """
        # The generator yields (group id of the line, line as string).
        generator = self._line_reader.lines_generator()
        rest_folds = []
        rest_fold_file = None

        try:
            # Opened inside ``try`` so that a failure while opening one storage still closes
            # the storages that were already opened.
            for fold_storage in fold_storages:
                fold_storage.open()

            rest_fold_file = self.create_fold(None, 'offset{}_rest'.format(offset), num)
            rest_fold_file.open()
            num += 1
            rest_size = 0

            for num_line, (_, line) in enumerate(generator):
                group_id = self._line_groups_ids[num_line]
                is_written = False
                for fold_storage in fold_storages:
                    if fold_storage.contains_group_id(group_id):
                        fold_storage.add(line)
                        is_written = True
                if not is_written:
                    rest_fold_file.add(line)
                    rest_size += 1
                    if rest_size >= self._REST_SIZE:
                        rest_folds.append(rest_fold_file)
                        rest_fold_file.close()
                        rest_fold_file = self.create_fold(None, 'offset{}_rest'.format(offset), num)
                        rest_fold_file.open()
                        rest_size = 0
                        num += 1

            if rest_size > 0:
                rest_fold_file.close()
                rest_folds.append(rest_fold_file)
            elif rest_fold_file.is_opened():
                # The last rest file received no lines; it is closed but not returned.
                rest_fold_file.close()
        finally:
            # Release everything still open, including the current rest fold file when an
            # exception interrupted the work above.
            for fold_storage in fold_storages:
                if fold_storage.is_opened():
                    fold_storage.close()
            if rest_fold_file is not None and rest_fold_file.is_opened():
                rest_fold_file.close()

        return rest_folds

    def create_fold_sets(self, fold_size, folds_count):
        """
        Create the group sets of all folds for all permutations.

        Permutations are generated until ``folds_count`` folds are collected; each permutation
        contributes as many folds as fit into the dataset.

        :param fold_size: Number of groups in every fold.
        :param folds_count: Total number of folds to create.
        :return: List with one list of fold group-id sets per permutation.
        :raises AttributeError: If ``fold_size`` is not positive or is too big.
        """
        folds = []
        passed_folds_count = 0

        while passed_folds_count < folds_count:
            current_learn_folds = self._make_learn_folds(fold_size, folds_count - passed_folds_count)
            folds.append(current_learn_folds)
            passed_folds_count += len(current_learn_folds)
        return folds

    def fold_groups_files_generator(self, folds_groups, fold_offset):
        """
        Create and fill fold files for all folds in ``folds_groups``. Generator.

        Folds are numbered consecutively from 1 across all groups. Folds with a number not
        greater than ``fold_offset`` are written as "skipped" folds, the others as learn folds.

        :param folds_groups: Fold groups, e.g. as returned by ``create_fold_sets``.
        :param fold_offset: Number of leading folds to write as skipped folds.
        :return: Yields ``(learn_folds, skipped_folds, rest_folds)`` for every fold group.
        """
        fold_num = 0
        for fold_group in folds_groups:
            learn_folds = []
            skipped_folds = []
            for learn_set in fold_group:
                fold_num += 1
                if fold_offset < fold_num:
                    learn_folds.append(self.create_fold(learn_set, 'fold', fold_num))
                else:
                    skipped_folds.append(
                        self.create_fold(learn_set, 'offset{}_skipped'.format(fold_offset), fold_num)
                    )

            rest_folds = self._write_folds(learn_folds + skipped_folds, fold_num, fold_offset)
            yield learn_folds, skipped_folds, rest_folds

    def create_fold(self, fold_set, name, id):
        """
        Create a fold file and remember it for later cleanup.

        :param fold_set: Set of group ids belonging to the fold, or None for a "rest" fold.
        :param name: Prefix of the fold file name.
        :param id: Fold number, appended to ``name`` zero-padded to 4 digits.
        :return: The created ``_FoldFile``.
        """
        file_name = self.create_name_from_id(name, id)
        fold_file = _FoldFile(fold_set,
                              file_name,
                              sep=self._line_reader.get_separator(),
                              column_description=self._column_description)
        self._folds_storage.add(fold_file)
        return fold_file

    def clean_folds(self):
        """Delete every fold file created by this splitter and forget about them."""
        for fold_file in self._folds_storage:
            fold_file.delete()
        # Forget the deleted files so that a repeated call does not delete them again.
        self._folds_storage.clear()

    def clean(self):
        """Remove the fold storage directory (delegates to ``FoldStorage.remove_dir``)."""
        FoldStorage.remove_dir()

    @staticmethod
    def create_name_from_id(name, id, offset=None, max_count_digits=4):
        """
        Build a name from a prefix and a zero-padded number.

        Example: ``create_name_from_id('fold', 7)`` -> ``'fold0007'`` and
        ``create_name_from_id('fold', 7, offset=2)`` -> ``'fold0007_offset2'``.

        :param name: Name prefix.
        :param id: Number appended after the prefix.
        :param offset: Optional offset appended as ``_offset{offset}``.
        :param max_count_digits: Width the number is zero-padded to.
        :return: The built name.
        """
        if offset is not None:
            name = '{name}{:0>{max_count_digits}}_offset{offset}'.format(
                id,
                name=name,
                max_count_digits=max_count_digits,
                offset=offset
            )
        else:
            name = '{name}{:0>{max_count_digits}}'.format(id, name=name, max_count_digits=max_count_digits)
        return name
|