PageRenderTime 1269ms CodeModel.GetById 32ms RepoModel.GetById 9ms app.codeStats 0ms

/components/optimization_guide/core/page_topics_model_executor.cc

https://github.com/chromium/chromium
C++ | 376 lines | 287 code | 57 blank | 32 comment | 41 complexity | fedc4246aacc3bd13d34a0d9e558b563 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, Apache-2.0, BSD-3-Clause
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/page_topics_model_executor.h"
  5. #include "base/barrier_closure.h"
  6. #include "base/containers/contains.h"
  7. #include "base/files/file_util.h"
  8. #include "base/strings/string_number_conversions.h"
  9. #include "base/strings/string_util.h"
  10. #include "components/optimization_guide/core/optimization_guide_model_provider.h"
  11. #include "components/optimization_guide/proto/models.pb.h"
  12. #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
  13. #include "components/optimization_guide/proto/page_topics_override_list.pb.h"
  14. #include "third_party/zlib/google/compression_utils.h"
  15. namespace optimization_guide {
  16. namespace {
  17. // The ID of the NONE category in the taxonomy. This node always exists.
  18. // Semantically, the none category is attached to data for which we can say
  19. // with certainty that no single label in the taxonomy is appropriate.
  20. const int32_t kNoneCategoryId = -2;
  21. const base::FilePath::CharType kOverrideListBasePath[] =
  22. FILE_PATH_LITERAL("override_list.pb.gz");
  23. // The result of an override list file load attempt. These values are logged to
  24. // UMA histograms, do not change or reorder values. Make sure to update
  25. // |OptimizationGuidePageTopicsOverrideListFileLoadResult| in
  26. // //tools/metrics/histograms/enums.xml.
  27. enum class OverrideListFileLoadResult {
  28. kUnknown = 0,
  29. kSuccess = 1,
  30. kCouldNotReadFile = 2,
  31. kCouldNotUncompressFile = 3,
  32. kCouldNotUnmarshalProtobuf = 4,
  33. kMaxValue = kCouldNotUnmarshalProtobuf,
  34. };
  35. void RecordOverrideListFileLoadResult(OverrideListFileLoadResult result) {
  36. base::UmaHistogramEnumeration(
  37. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult", result);
  38. }
  39. absl::optional<std::unordered_map<std::string, std::vector<WeightedIdentifier>>>
  40. LoadOverrideListFromFile(const base::FilePath& path) {
  41. if (!path.IsAbsolute() ||
  42. path.BaseName() != base::FilePath(kOverrideListBasePath)) {
  43. NOTREACHED();
  44. // This is enforced by calling code, so no UMA in this case.
  45. return absl::nullopt;
  46. }
  47. std::string file_contents;
  48. if (!base::ReadFileToString(path, &file_contents)) {
  49. RecordOverrideListFileLoadResult(
  50. OverrideListFileLoadResult::kCouldNotReadFile);
  51. return absl::nullopt;
  52. }
  53. if (!compression::GzipUncompress(file_contents, &file_contents)) {
  54. RecordOverrideListFileLoadResult(
  55. OverrideListFileLoadResult::kCouldNotUncompressFile);
  56. return absl::nullopt;
  57. }
  58. proto::PageTopicsOverrideList override_list_pb;
  59. if (!override_list_pb.ParseFromString(file_contents)) {
  60. RecordOverrideListFileLoadResult(
  61. OverrideListFileLoadResult::kCouldNotUnmarshalProtobuf);
  62. return absl::nullopt;
  63. }
  64. std::unordered_map<std::string, std::vector<WeightedIdentifier>>
  65. override_list;
  66. for (const proto::PageTopicsOverrideEntry& entry :
  67. override_list_pb.entries()) {
  68. std::vector<WeightedIdentifier> topics;
  69. topics.reserve(entry.topics().topic_ids_size());
  70. for (int32_t topic : entry.topics().topic_ids()) {
  71. // Always give overridden topics full weight.
  72. topics.emplace_back(WeightedIdentifier(topic, 1.0));
  73. }
  74. override_list.emplace(entry.domain(), std::move(topics));
  75. }
  76. RecordOverrideListFileLoadResult(OverrideListFileLoadResult::kSuccess);
  77. return override_list;
  78. }
  79. } // namespace
  80. PageTopicsModelExecutor::PageTopicsModelExecutor(
  81. OptimizationGuideModelProvider* model_provider,
  82. scoped_refptr<base::SequencedTaskRunner> background_task_runner,
  83. const absl::optional<proto::Any>& model_metadata)
  84. : BertModelHandler(model_provider,
  85. background_task_runner,
  86. proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
  87. model_metadata),
  88. background_task_runner_(background_task_runner) {
  89. SetShouldUnloadModelOnComplete(false);
  90. }
  91. PageTopicsModelExecutor::~PageTopicsModelExecutor() = default;
  92. void PageTopicsModelExecutor::ExecuteJob(
  93. base::OnceClosure on_job_complete_callback,
  94. std::unique_ptr<PageContentAnnotationJob> job) {
  95. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  96. DCHECK_EQ(job->type(), AnnotationType::kPageTopics);
  97. // Check if there is an override list available but not loaded yet.
  98. if (override_list_file_path_ && !override_list_) {
  99. background_task_runner_->PostTaskAndReplyWithResult(
  100. FROM_HERE,
  101. base::BindOnce(&LoadOverrideListFromFile, *override_list_file_path_),
  102. base::BindOnce(&PageTopicsModelExecutor::OnOverrideListLoadAttemptDone,
  103. weak_ptr_factory_.GetWeakPtr(),
  104. std::move(on_job_complete_callback), std::move(job)));
  105. return;
  106. }
  107. PageContentAnnotationJobExecutor::ExecuteJob(
  108. std::move(on_job_complete_callback), std::move(job));
  109. }
  110. // static
  111. std::string PageTopicsModelExecutor::PreprocessHost(const std::string& host) {
  112. std::string output = base::ToLowerASCII(host);
  113. // Strip the 'www.' if it exists.
  114. if (base::StartsWith(output, "www.")) {
  115. output = output.substr(4);
  116. }
  117. static const char kCharsToReplaceWithSpace[] = {'-', '_', '.', '+'};
  118. for (char c : kCharsToReplaceWithSpace) {
  119. std::replace(output.begin(), output.end(), c, ' ');
  120. }
  121. return output;
  122. }
  123. void PageTopicsModelExecutor::ExecuteOnSingleInput(
  124. AnnotationType annotation_type,
  125. const std::string& raw_input,
  126. base::OnceCallback<void(const BatchAnnotationResult&)> callback) {
  127. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  128. DCHECK_EQ(annotation_type, AnnotationType::kPageTopics);
  129. // |processed_input| is needed by the override list and the model, but we pass
  130. // the |raw_input| to where the BatchAnnotationResult is created so that the
  131. // original input is passed back to the caller.
  132. std::string processed_input = PreprocessHost(raw_input);
  133. if (override_list_) {
  134. DCHECK(override_list_file_path_);
  135. auto iter = override_list_->find(processed_input);
  136. base::UmaHistogramBoolean(
  137. "OptimizationGuide.PageTopicsOverrideList.UsedOverride",
  138. iter != override_list_->end());
  139. if (iter != override_list_->end()) {
  140. std::move(callback).Run(BatchAnnotationResult::CreatePageTopicsResult(
  141. raw_input, iter->second));
  142. return;
  143. }
  144. }
  145. ExecuteModelWithInput(
  146. base::BindOnce(&PageTopicsModelExecutor::
  147. PostprocessCategoriesToBatchAnnotationResult,
  148. weak_ptr_factory_.GetWeakPtr(), std::move(callback),
  149. annotation_type, raw_input),
  150. processed_input);
  151. }
  152. void PageTopicsModelExecutor::OnOverrideListLoadAttemptDone(
  153. base::OnceClosure on_job_complete_callback,
  154. std::unique_ptr<PageContentAnnotationJob> job,
  155. absl::optional<
  156. std::unordered_map<std::string, std::vector<WeightedIdentifier>>>
  157. override_list) {
  158. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  159. override_list_ = override_list;
  160. if (!override_list) {
  161. // Clear the file path so we don't try to load it again.
  162. override_list_file_path_ = absl::nullopt;
  163. }
  164. // Now we're ready to run the job! Call the base class to do so.
  165. PageContentAnnotationJobExecutor::ExecuteJob(
  166. std::move(on_job_complete_callback), std::move(job));
  167. }
  168. void PageTopicsModelExecutor::PostprocessCategoriesToBatchAnnotationResult(
  169. base::OnceCallback<void(const BatchAnnotationResult&)> callback,
  170. AnnotationType annotation_type,
  171. const std::string& raw_input,
  172. const absl::optional<std::vector<tflite::task::core::Category>>& output) {
  173. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  174. DCHECK_EQ(annotation_type, AnnotationType::kPageTopics);
  175. absl::optional<std::vector<WeightedIdentifier>> categories;
  176. if (output) {
  177. categories = ExtractCategoriesFromModelOutput(*output);
  178. }
  179. std::move(callback).Run(
  180. BatchAnnotationResult::CreatePageTopicsResult(raw_input, categories));
  181. }
  182. absl::optional<std::vector<WeightedIdentifier>>
  183. PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
  184. const std::vector<tflite::task::core::Category>& model_output) const {
  185. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  186. absl::optional<proto::PageTopicsModelMetadata> model_metadata =
  187. ParsedSupportedFeaturesForLoadedModel<proto::PageTopicsModelMetadata>();
  188. if (!model_metadata) {
  189. return absl::nullopt;
  190. }
  191. absl::optional<std::string> visibility_category_name =
  192. model_metadata->output_postprocessing_params().has_visibility_params() &&
  193. model_metadata->output_postprocessing_params()
  194. .visibility_params()
  195. .has_category_name()
  196. ? absl::make_optional(model_metadata->output_postprocessing_params()
  197. .visibility_params()
  198. .category_name())
  199. : absl::nullopt;
  200. std::vector<std::pair<int32_t, float>> category_candidates;
  201. for (const auto& category : model_output) {
  202. if (visibility_category_name &&
  203. category.class_name == *visibility_category_name) {
  204. continue;
  205. }
  206. // Assume everything else is for categories.
  207. int category_id;
  208. if (base::StringToInt(category.class_name, &category_id)) {
  209. category_candidates.emplace_back(
  210. std::make_pair(category_id, static_cast<float>(category.score)));
  211. }
  212. }
  213. // Postprocess categories.
  214. if (!model_metadata->output_postprocessing_params().has_category_params()) {
  215. // No parameters for postprocessing, so just return.
  216. return absl::nullopt;
  217. }
  218. const proto::PageTopicsCategoryPostprocessingParams category_params =
  219. model_metadata->output_postprocessing_params().category_params();
  220. // Determine the categories with the highest weights.
  221. std::sort(
  222. category_candidates.begin(), category_candidates.end(),
  223. [](const std::pair<int32_t, float>& a,
  224. const std::pair<int32_t, float>& b) { return a.second > b.second; });
  225. size_t max_categories = static_cast<size_t>(category_params.max_categories());
  226. float total_weight = 0.0;
  227. float sum_positive_scores = 0.0;
  228. absl::optional<std::pair<size_t, float>> none_idx_and_weight;
  229. std::vector<std::pair<int32_t, float>> categories;
  230. categories.reserve(max_categories);
  231. for (size_t i = 0; i < category_candidates.size() && i < max_categories;
  232. i++) {
  233. std::pair<int32_t, float> candidate = category_candidates[i];
  234. categories.push_back(candidate);
  235. total_weight += candidate.second;
  236. if (candidate.second > 0)
  237. sum_positive_scores += candidate.second;
  238. if (candidate.first == kNoneCategoryId) {
  239. none_idx_and_weight = std::make_pair(i, candidate.second);
  240. }
  241. }
  242. // Prune out categories that do not meet the minimum threshold.
  243. if (category_params.min_category_weight() > 0) {
  244. categories.erase(
  245. std::remove_if(categories.begin(), categories.end(),
  246. [&](const std::pair<int32_t, float>& category) {
  247. return category.second <
  248. category_params.min_category_weight();
  249. }),
  250. categories.end());
  251. }
  252. // Prune out none weights.
  253. if (total_weight == 0) {
  254. return absl::nullopt;
  255. }
  256. if (none_idx_and_weight) {
  257. if ((none_idx_and_weight->second / total_weight) >
  258. category_params.min_none_weight()) {
  259. // None weight is too strong.
  260. return absl::nullopt;
  261. }
  262. // None weight doesn't matter, so prune it out. Note that it may have
  263. // already been removed above if its weight was below the category min.
  264. categories.erase(
  265. std::remove_if(categories.begin(), categories.end(),
  266. [&](const std::pair<int32_t, float>& category) {
  267. return category.first == kNoneCategoryId;
  268. }),
  269. categories.end());
  270. }
  271. // Normalize category weights.
  272. float normalization_factor =
  273. sum_positive_scores > 0 ? sum_positive_scores : 1.0;
  274. categories.erase(
  275. std::remove_if(
  276. categories.begin(), categories.end(),
  277. [&](const std::pair<int32_t, float>& category) {
  278. return (category.second / normalization_factor) <
  279. category_params.min_normalized_weight_within_top_n();
  280. }),
  281. categories.end());
  282. std::vector<WeightedIdentifier> final_categories;
  283. final_categories.reserve(categories.size());
  284. for (const auto& category : categories) {
  285. // We expect the weight to be between 0 and 1.
  286. DCHECK(category.second >= 0.0 && category.second <= 1.0);
  287. final_categories.emplace_back(
  288. WeightedIdentifier(category.first, category.second));
  289. }
  290. DCHECK_LE(final_categories.size(), max_categories);
  291. return final_categories;
  292. }
  293. void PageTopicsModelExecutor::UnloadModel() {
  294. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  295. BertModelHandler::UnloadModel();
  296. override_list_ = absl::nullopt;
  297. }
  298. void PageTopicsModelExecutor::OnModelUpdated(
  299. proto::OptimizationTarget optimization_target,
  300. const ModelInfo& model_info) {
  301. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  302. BertModelHandler::OnModelUpdated(optimization_target, model_info);
  303. if (optimization_target != proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2) {
  304. return;
  305. }
  306. // New model, new override list.
  307. override_list_file_path_ = absl::nullopt;
  308. override_list_ = absl::nullopt;
  309. for (const base::FilePath& path : model_info.GetAdditionalFiles()) {
  310. DCHECK(path.IsAbsolute());
  311. if (path.BaseName() == base::FilePath(kOverrideListBasePath)) {
  312. override_list_file_path_ = path;
  313. break;
  314. }
  315. }
  316. base::UmaHistogramBoolean("OptimizationGuide.PageTopicsOverrideList.GotFile",
  317. !!override_list_file_path_);
  318. }
  319. } // namespace optimization_guide