#!/usr/bin/env ruby

# Copyright (c) 2015 Alexandra Figlovskaya <fglval@gmail.com>
# Copyright (c) 2015-2019 Aleksey Cheusov <vle@gmx.net>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

require 'optparse'

# Run-time configuration; the option handlers below fill these in.
$options = {}
$fold_cnt = nil            # number of cross-validation folds (-c); 0 in hold-out mode
$ratio = nil               # share of the training set in percent (-R)
$tmp_dir = nil             # directory the output files are written to (-d)
$seed = Random.new_seed    # overridden by a non-empty -s argument
$stratified = true         # switched off by -r

OptionParser.new do |opts|
  opts.banner = <<EOF
heri-split splits the given dataset (in svmlight format)
  into training and testing sets for further evaluation.
  If option -c is specified, test<N>.txt and train<N>.txt files
  (also in svmlight format) are created, where N is the number of fold.
  If option -R is specified, test.txt and train.txt files
  are created for the same purposes.
  Also testing_fold.txt file is created, where for each object (one per line)
  its testing fold number is specified if option -c is applied.
  The file testing_fold.txt contains either 1 for testing set and 0
  for training set, if option -R is applied.
Usage:
    heri-split [OPTIONS] dataset1 [dataset2...]
OPTIONS:
EOF

  opts.on('-h', '--help','display this message and exit') do
    puts opts
    exit 0
  end

  opts.on("-cFOLD_CNT", "--folds=FOLD_CNT", "A number of folds (mandatory option)") do |c|
    $fold_cnt = c.to_i
  end

  opts.on("-dDIR", "--output-dir=DIR", "Output directory (mandatory option)") do |d|
    $tmp_dir = d
  end

  opts.on("-sSEED", "--seed=SEED", "Seed for pseudo-random number generator") do |s|
    # An empty SEED keeps the randomly generated default.
    if s != "" then
      $seed = s.to_i
    end
  end

  opts.on("-r", "--random", "Use random split instead of stratified") do
    $stratified = false
  end

  opts.on("-R", "--ratio=RATIO", "Split input dataset into training and testing sets
                                     with the specified RATIO (in percents)") do |r|
    $ratio = r.to_i
    $fold_cnt = 0
  end

  opts.separator " "
end.parse!

if $tmp_dir.nil? || ($fold_cnt.nil? && $ratio.nil?)
  STDERR.puts("Options -c/-R and -d are mandatory, see heri-split -h for details")
  exit(1)
end

if $ratio.nil?
  # Cross-validation needs at least one fold; String#to_i turns garbage into 0.
  if $fold_cnt < 1
    STDERR.puts("FOLD_CNT must be a positive integer, see heri-split -h for details")
    exit(1)
  end
else
  # -R selects hold-out mode wherever it appears on the command line.
  # Without this, "-R 80 -c 5" left $fold_cnt at 5 with $ratio set, so the
  # stratified splitter wrote the hold-out split into train1.txt/test1.txt.
  $fold_cnt = 0
  unless (0..100).cover?($ratio)
    STDERR.puts("RATIO must be a percentage between 0 and 100, see heri-split -h for details")
    exit(1)
  end
end

# Pseudo-random generator used by every splitter, seeded with $seed.
$rnd = Random.new($seed)

# same as in StratifiedSplitter
# Opens <output dir>/<stem>.txt for writing in binary (ASCII-8BIT) mode.
open_output = lambda do |stem|
  File.open($tmp_dir + "/" + stem + ".txt", 'w:ASCII-8BIT')
end

$files_test = []
$files_train = []
$testing_fold = open_output.call("testing_fold")

# One test<N>.txt/train<N>.txt pair per fold, N counted from 1.
1.upto($fold_cnt) do |fold|
  $files_test << open_output.call("test#{fold}")
  $files_train << open_output.call("train#{fold}")
end

# With a ratio set, a single test.txt/train.txt pair is appended as well.
unless $ratio.nil?
  $files_test << open_output.call("test")
  $files_train << open_output.call("train")
end

# Non-stratified cross-validation split.
#
# Reads every file named in ARGV twice. A data line is a line that starts
# with a non-blank token followed by whitespace; other lines are skipped.
# Each data line is assigned to exactly one of $fold_cnt folds: it is
# written to the test file of that fold and to the train files of all the
# other folds, and its 1-based fold number is appended to $testing_fold.
#
# Uses the globals $fold_cnt, $rnd, $files_test, $files_train and
# $testing_fold. Input files are closed as soon as they have been read
# (the non-block File.open used before left every handle open).
def random_split_cv()
  # Pass 1: one fold id per data line, dealt round-robin so that fold sizes
  # differ by at most one, then shuffled with the shared generator.
  fold_ids = []
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        fold_ids << (fold_ids.size % $fold_cnt) if line =~ /^([^\s]+)\s/
      end
    end
  end

  fold_ids.shuffle!(random: $rnd)

  # Pass 2: route each data line according to its shuffled fold id.
  line_idx = 0
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        next unless line =~ /^([^\s]+)\s/

        test_fold = fold_ids[line_idx]
        $fold_cnt.times do |n|
          if n == test_fold
            $files_test[n].puts line
            $testing_fold.puts n + 1
          else
            $files_train[n].puts line
          end
        end

        line_idx += 1
      end
    end
  end
end

# Non-stratified hold-out split.
#
# Reads every file named in ARGV twice. A data line is a line that starts
# with a non-blank token followed by whitespace; other lines are skipped.
# Every data line gets a random number from $rnd; the lines with the
# smallest numbers (about $ratio percent of them, but at least one) go to
# $files_train[0] and are marked '0' in $testing_fold, the rest go to
# $files_test[0] and are marked '1'.
#
# When the training share covers the whole dataset (e.g. $ratio == 100 or
# a single data line) there is no cut-off value, and every line is written
# to the training set. Previously this case raised on "Float < nil".
# Input files are closed as soon as they have been read.
def random_split_holdout()
  # Pass 1: one random key per data line, in reading order.
  keys = []
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        keys << $rnd.rand() if line =~ /^([^\s]+)\s/
      end
    end
  end

  train_cnt = (keys.size.to_f * $ratio / 100).to_i
  train_cnt = 1 if train_cnt == 0
  # Key of the first line that does NOT belong to the training set;
  # nil when train_cnt reaches past the last data line.
  cutoff = keys.sort[train_cnt]

  # Pass 2: lines whose key is below the cut-off form the training set.
  line_idx = 0
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        next unless line =~ /^([^\s]+)\s/

        if cutoff.nil? || keys[line_idx] < cutoff
          $files_train[0].puts line
          $testing_fold.puts '0'
        else
          $files_test[0].puts line
          $testing_fold.puts '1'
        end
        line_idx += 1
      end
    end
  end
end

# Non-stratified split entry point: cross-validation when $fold_cnt is a
# positive number, hold-out split otherwise.
def random_split()
  cv_mode = !$fold_cnt.nil? && $fold_cnt > 0
  return random_split_cv() if cv_mode

  random_split_holdout()
end

# Stratified split: the split is done separately inside every class, where
# the class of a data line is its first whitespace-delimited token.
#
# Reads every file named in ARGV twice. Lines that do not start with a
# non-blank token followed by whitespace are skipped.
#
# * Hold-out mode ($ratio is set): per class, (size * $ratio / 100).to_i
#   randomly chosen lines go to $files_train[0] (marked "0" in
#   $testing_fold), the remaining ones to $files_test[0] (marked "1").
# * Cross-validation mode ($ratio is nil): per class, the lines are spread
#   over $fold_cnt folds in shuffled order; each line is written to the
#   test file of its fold and the train files of all other folds, and its
#   1-based fold number is appended to $testing_fold.
#
# Uses the globals $ratio, $fold_cnt, $rnd, $files_test, $files_train and
# $testing_fold. Input files are closed as soon as they have been read
# (the non-block File.open used before left every handle open).
def stratified_split()
  # Pass 1: number of data lines per class.
  class_sizes = Hash.new(0)
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        class_sizes[$1] += 1 if line =~ /^([^\s]+)\s/
      end
    end
  end

  # assignment[label][k] describes the k-th line of class `label`:
  # true/false ("goes to training set") in hold-out mode,
  # a 0-based fold id in cross-validation mode.
  assignment = {}
  class_sizes.each do |label, size|
    order = (0...size).to_a
    order.shuffle!(random: $rnd)
    per_class = assignment[label] = {}
    if $ratio != nil
      train_cnt = (size.to_f * $ratio / 100).to_i
      order.each_with_index { |rank, k| per_class[k] = (rank < train_cnt) }
    else
      order.each_with_index do |k, pos|
        per_class[k] = (pos * $fold_cnt.to_f / size).to_i
      end
    end
  end

  # Pass 2: route every data line according to its class-local index.
  seen = Hash.new(0)
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |input|
      input.each_line do |line|
        next unless line =~ /^([^\s]+)\s/

        label = $1
        slot = assignment[label][seen[label]]
        if $ratio != nil
          if slot
            $files_train[0].puts line
            $testing_fold.puts "0"
          else
            $files_test[0].puts line
            $testing_fold.puts "1"
          end
        else
          $fold_cnt.times do |n|
            if slot == n
              $files_test[n].puts line
              $testing_fold.puts n + 1
            else
              $files_train[n].puts line
            end
          end
        end
        seen[label] += 1
      end
    end
  end
end

# Run the selected splitter, then close every output file:
# the test files first, then the train files, then testing_fold.txt.
$stratified ? stratified_split() : random_split()

($files_test + $files_train).each(&:close)
$testing_fold.close
