PageRenderTime 44ms CodeModel.GetById 22ms RepoModel.GetById 1ms app.codeStats 0ms

/src/main/scala/updown/app/experiment/NFoldExperiment.scala

https://bitbucket.org/speriosu/updown
Scala | 85 lines | 68 code | 12 blank | 5 comment | 3 complexity | aa0d25bfa1d51bc0499e0e5ef5234df2 MD5 | raw file
  1. package updown.app.experiment
  2. import updown.data.io.TweetFeatureReader
  3. import org.clapper.argot.ArgotParser._
  4. import org.clapper.argot.ArgotConverters._
  5. import com.weiglewilczek.slf4s.Logging
  6. import updown.util.Statistics
  7. import org.clapper.argot.{SingleValueOption, ArgotUsageException, ArgotParser}
  8. import updown.data.{SystemLabeledTweet, SentimentLabel, GoldLabeledTweet}
  /**
   * Base class for stratified n-fold cross-validation experiments over gold-labeled tweets.
   *
   * Subclasses implement `doExperiment`. `apply` parses the command line and builds folds from
   * the gold input file. It then runs one `doExperiment` call per fold and hands the pooled
   * system-labeled tweets to `report`.
   *
   * NOTE(review): `parser`, `logger` and `report` are not defined in this class; presumably they
   * are inherited from `Experiment` — confirm.
   */
  abstract class NFoldExperiment extends Experiment {
    // this exists purely to make the ArgotConverters appear used to IDEA
    convertByte _

    val goldInputFile = parser.option[String](List("g", "gold"), "gold", "gold labeled input")
    val n = parser.option[Int](List("n", "folds"), "FOLDS", "the number of folds for the experiment (default 10)")

    /** Counter of runs started by `apply`: incremented just before each `doExperiment` call, never reset. */
    var experimentalRun = 0

    /**
     * Runs a single trial.
     *
     * @param train tweets to train on (all folds except the held-out one)
     * @param test  the held-out fold
     * @return system-labeled tweets; `apply` pools these across all folds and passes them to `report`
     */
    def doExperiment(train: List[GoldLabeledTweet], test: List[GoldLabeledTweet]): List[SystemLabeledTweet]

    /**
     * Splits the tweets read from `inputFile` into label-balanced folds.
     *
     * Tweets are grouped by gold label. Every label class is truncated to the size of the
     * smallest class, so each fold holds the same number of tweets per label. The i-th tweet of
     * every class goes to fold `i % nFolds`. If the smallest class has fewer than `nFolds` tweets,
     * fewer than `nFolds` folds are produced.
     *
     * @param inputFile path handed to `TweetFeatureReader`
     * @param nFolds    requested number of folds; must be positive
     * @return one `(heldOutTest, train)` pair per fold. Pairs are built lazily, so only the
     *         training list of the trial currently being consumed has to be materialized.
     * @throws IllegalArgumentException if `nFolds` is not positive or no tweets were read
     */
    def generateTrials(inputFile: String, nFolds: Int): Iterator[(List[GoldLabeledTweet], List[GoldLabeledTweet])] = {
      // Without this check, `index % nFolds` below fails with an opaque "/ by zero".
      require(nFolds > 0, "the number of folds must be positive, got " + nFolds)
      val polToTweetLists = TweetFeatureReader(inputFile).groupBy((tweet) => tweet.goldLabel)
      // Without this check, `.min` below fails with an opaque "empty.min".
      require(polToTweetLists.nonEmpty, "no tweets were read from " + inputFile)
      val minListLength = polToTweetLists.values.map(_.length).min
      logger.info("taking %d items from each polarity class. This was the minimum number in any class".format(minListLength))

      // Work on `values` rather than on the Map itself: a for/yield over a Map that yields pairs
      // builds another Map, which silently drops entries whose keys collide.
      val balancedClasses: List[List[GoldLabeledTweet]] =
        polToTweetLists.values.map(_.take(minListLength).toList).toList

      // After truncation every class has the same length. Row i of the transpose therefore holds
      // the i-th tweet of each class, in the Map's iteration order. This yields the same sequence
      // as indexing each class by i, without the O(n^2) cost of repeated List indexing.
      val allTweetsFolded: List[(Int, GoldLabeledTweet)] =
        balancedClasses.transpose.zipWithIndex.flatMap {
          case (row, index) => row.map(tweet => (index % nFolds, tweet))
        }

      val foldsToTweets: Map[Int, List[GoldLabeledTweet]] =
        allTweetsFolded.groupBy { case (fold, _) => fold }.map {
          case (fold, pairs) => (fold, pairs.map { case (_, tweet) => tweet })
        }

      // Map over an Iterator, not over the Map. Otherwise the result would be rebuilt into a Map
      // keyed by the held-out lists, and every training list would be materialized up front.
      val folds = foldsToTweets.toList
      folds.iterator.map {
        case (heldOutFold, heldOutData) =>
          val trainData = folds.filter {
            case (setNo, _) => setNo != heldOutFold
          }.flatMap {
            case (_, list) => list
          }
          (heldOutData, trainData)
      }
    }

    /**
     * Command-line entry point.
     *
     * Parses `args`: `-g` is the required gold input file and `-n` is the fold count, default 10.
     * It runs `doExperiment` once per fold and passes the pooled results to `report`. Usage
     * errors are printed to stdout and terminate the JVM with exit status 1.
     */
    def apply(args: Array[String]): Unit = {
      try {
        parser.parse(args)
        val nFolds: Int = n.value.getOrElse(10)
        if (goldInputFile.value.isEmpty) {
          parser.usage("You must specify a gold labeled input file via -g.")
        }
        if (nFolds < 1) {
          parser.usage("The number of folds (-n) must be a positive integer.")
        }
        // Safe: the guard above throws (via parser.usage) when the option is missing.
        val inputFile = goldInputFile.value.get
        val results =
          (for ((testSet, trainSet) <- generateTrials(inputFile, nFolds)) yield {
            experimentalRun += 1
            logger.debug("starting run " + experimentalRun)
            val result = doExperiment(trainSet, testSet)
            logger.debug("ending run " + experimentalRun)
            result
          }).toList
        val result = results.flatten
        logger.info("Final Result:")
        report(inputFile, result)
        logger.debug("running cleanup code")
      } catch {
        case e: ArgotUsageException => println(e.message); sys.exit(1)
      }
    }
  }