PageRenderTime 48ms CodeModel.GetById 25ms RepoModel.GetById 0ms app.codeStats 0ms

/components/optimization_guide/core/prediction_manager.h

https://github.com/chromium/chromium
C Header | 333 lines | 154 code | 68 blank | 111 comment | 0 complexity | c6933f2a54af131ae6f2870930b92934 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, Apache-2.0, BSD-3-Clause
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MANAGER_H_
  5. #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MANAGER_H_
  #include <stdint.h>

  #include <map>
  #include <memory>
  #include <string>
  #include <vector>

  #include "base/callback.h"
  #include "base/containers/flat_map.h"
  #include "base/containers/flat_set.h"
  #include "base/containers/lru_cache.h"
  #include "base/files/file_path.h"
  #include "base/memory/raw_ptr.h"
  #include "base/memory/weak_ptr.h"
  #include "base/observer_list.h"
  #include "base/sequence_checker.h"
  #include "base/time/clock.h"
  #include "base/timer/timer.h"
  #include "components/optimization_guide/core/model_enums.h"
  #include "components/optimization_guide/core/optimization_guide_enums.h"
  #include "components/optimization_guide/core/prediction_model_download_observer.h"
  #include "components/optimization_guide/proto/models.pb.h"
  #include "url/origin.h"
  25. namespace download {
  26. class BackgroundDownloadService;
  27. } // namespace download
  28. namespace network {
  29. class SharedURLLoaderFactory;
  30. } // namespace network
  31. class OptimizationGuideLogger;
  32. class PrefService;
  33. namespace optimization_guide {
  34. enum class OptimizationGuideDecision;
  35. class OptimizationGuideStore;
  36. class OptimizationTargetModelObserver;
  37. class PredictionModelDownloadManager;
  38. class PredictionModelFetcher;
  39. class ModelInfo;
  40. // A PredictionManager supported by the optimization guide that makes an
  41. // OptimizationTargetDecision by evaluating the corresponding prediction model
  42. // for an OptimizationTarget.
  43. class PredictionManager : public PredictionModelDownloadObserver {
  44. public:
  45. // BackgroundDownloadService is only available once the profile is fully
  46. // initialized and that cannot be done as part of |Initialize|. Get a provider
  47. // to retrieve the service when it is needed.
  48. using BackgroundDownloadServiceProvider =
  49. base::OnceCallback<download::BackgroundDownloadService*(void)>;
  50. // Callback to whether component updates are enabled for the browser.
  51. using ComponentUpdatesEnabledProvider = base::RepeatingCallback<bool(void)>;
  52. PredictionManager(
  53. base::WeakPtr<OptimizationGuideStore> model_and_features_store,
  54. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  55. PrefService* pref_service,
  56. bool off_the_record,
  57. const std::string& application_locale,
  58. const base::FilePath& models_dir_path,
  59. OptimizationGuideLogger* optimization_guide_logger,
  60. BackgroundDownloadServiceProvider background_download_service_provider,
  61. ComponentUpdatesEnabledProvider component_updates_enabled_provider);
  62. PredictionManager(const PredictionManager&) = delete;
  63. PredictionManager& operator=(const PredictionManager&) = delete;
  64. ~PredictionManager() override;
  65. // Adds an observer for updates to the model for |optimization_target|.
  66. //
  67. // It is assumed that any model retrieved this way will be passed to the
  68. // Machine Learning Service for inference.
  69. void AddObserverForOptimizationTargetModel(
  70. proto::OptimizationTarget optimization_target,
  71. const absl::optional<proto::Any>& model_metadata,
  72. OptimizationTargetModelObserver* observer);
  73. // Removes an observer for updates to the model for |optimization_target|.
  74. //
  75. // If |observer| is registered for multiple targets, |observer| must be
  76. // removed for all observed targets for in order for it to be fully
  77. // removed from receiving any calls.
  78. void RemoveObserverForOptimizationTargetModel(
  79. proto::OptimizationTarget optimization_target,
  80. OptimizationTargetModelObserver* observer);
  81. // Set the prediction model fetcher for testing.
  82. void SetPredictionModelFetcherForTesting(
  83. std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher);
  84. PredictionModelFetcher* prediction_model_fetcher() const {
  85. return prediction_model_fetcher_.get();
  86. }
  87. // Set the prediction model download manager for testing.
  88. void SetPredictionModelDownloadManagerForTesting(
  89. std::unique_ptr<PredictionModelDownloadManager>
  90. prediction_model_download_manager);
  91. PredictionModelDownloadManager* prediction_model_download_manager() const {
  92. return prediction_model_download_manager_.get();
  93. }
  94. base::WeakPtr<OptimizationGuideStore> model_and_features_store() const {
  95. return model_and_features_store_;
  96. }
  97. // Return the optimization targets that are registered.
  98. base::flat_set<proto::OptimizationTarget> GetRegisteredOptimizationTargets()
  99. const;
  100. // Override |clock_| for testing.
  101. void SetClockForTesting(const base::Clock* clock);
  102. // Override the model file returned to observers for |optimization_target|.
  103. // Use |TestModelInfoBuilder| to construct the model files. For
  104. // testing purposes only.
  105. void OverrideTargetModelForTesting(
  106. proto::OptimizationTarget optimization_target,
  107. std::unique_ptr<ModelInfo> model_info);
  108. // PredictionModelDownloadObserver:
  109. void OnModelReady(const proto::PredictionModel& model) override;
  110. void OnModelDownloadStarted(
  111. proto::OptimizationTarget optimization_target) override;
  112. void OnModelDownloadFailed(
  113. proto::OptimizationTarget optimization_target) override;
  114. protected:
  115. // Process |prediction_models| to be stored in the in memory optimization
  116. // target prediction model map for immediate use and asynchronously write the
  117. // models to the model and features store to be persisted.
  118. void UpdatePredictionModels(
  119. const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
  120. prediction_models);
  121. private:
  122. friend class PredictionManagerTestBase;
  123. // Called on construction to initialize the prediction model.
  124. // |background_dowload_service_provider| can provide the
  125. // BackgroundDownloadService if needed to download models.
  126. void Initialize(
  127. BackgroundDownloadServiceProvider background_dowload_service_provider);
  128. // Called to make a request to fetch models from the remote Optimization Guide
  129. // Service. Used to fetch models for the registered optimization targets.
  130. // |is_first_model_fetch| indicates whether this is the first model fetch
  131. // happening at startup, and is used to record metrics.
  132. void FetchModels(bool is_first_model_fetch);
  133. // Callback when the models have been fetched from the remote Optimization
  134. // Guide Service and are ready for parsing. Processes the prediction models in
  135. // the response and stores them for use. The metadata entry containing the
  136. // time that updates should be fetched from the remote Optimization Guide
  137. // Service is updated, even when the response is empty.
  138. void OnModelsFetched(const std::vector<proto::ModelInfo> models_request_info,
  139. absl::optional<std::unique_ptr<proto::GetModelsResponse>>
  140. get_models_response_data);
  141. // Callback run after the model and host model features store is fully
  142. // initialized. The prediction manager can load models from
  143. // the store for registered optimization targets. |store_is_ready_| is set to
  144. // true.
  145. void OnStoreInitialized(
  146. BackgroundDownloadServiceProvider background_dowload_service_provider);
  147. // Callback run after prediction models are stored in
  148. // |model_and_features_store_|.
  149. void OnPredictionModelsStored();
  150. // Load models for every target in |optimization_targets| that have not yet
  151. // been loaded from the store.
  152. void LoadPredictionModels(
  153. const base::flat_set<proto::OptimizationTarget>& optimization_targets);
  154. // Callback run after a prediction model is loaded from the store.
  155. // |prediction_model| is used to construct a PredictionModel capable of making
  156. // prediction for the appropriate |optimization_target|.
  157. void OnLoadPredictionModel(
  158. proto::OptimizationTarget optimization_target,
  159. bool record_availability_metrics,
  160. std::unique_ptr<proto::PredictionModel> prediction_model);
  161. // Callback run after a prediction model is loaded from a command-line
  162. // override.
  163. void OnPredictionModelOverrideLoaded(
  164. proto::OptimizationTarget optimization_target,
  165. std::unique_ptr<proto::PredictionModel> prediction_model);
  166. // Process loaded |model| into memory. Return true if a prediction
  167. // model object was created and successfully stored, otherwise false.
  168. bool ProcessAndStoreLoadedModel(const proto::PredictionModel& model);
  169. // Return whether the model stored in memory for |optimization_target| should
  170. // be updated based on what's currently stored and |new_version|.
  171. bool ShouldUpdateStoredModelForTarget(
  172. proto::OptimizationTarget optimization_target,
  173. int64_t new_version) const;
  174. // Updates the in-memory model file for |optimization_target| to
  175. // |prediction_model_file|.
  176. void StoreLoadedModelInfo(proto::OptimizationTarget optimization_target,
  177. std::unique_ptr<ModelInfo> prediction_model_file);
  178. // Post-processing callback invoked after processing |model|.
  179. void OnProcessLoadedModel(const proto::PredictionModel& model, bool success);
  180. // Return the time when a prediction model fetch was last attempted.
  181. base::Time GetLastFetchAttemptTime() const;
  182. // Set the last time when a prediction model fetch was last attempted to
  183. // |last_attempt_time|.
  184. void SetLastModelFetchAttemptTime(base::Time last_attempt_time);
  185. // Return the time when a prediction model fetch was last successfully
  186. // completed.
  187. base::Time GetLastFetchSuccessTime() const;
  188. // Set the last time when a fetch for prediction models last succeeded to
  189. // |last_success_time|.
  190. void SetLastModelFetchSuccessTime(base::Time last_success_time);
  191. // Schedule first fetch for models if enabled for this profile.
  192. void MaybeScheduleFirstModelFetch();
  193. // Schedule |fetch_timer_| to fire based on:
  194. // 1. The update time for models in the store and
  195. // 2. The last time a fetch attempt was made.
  196. void ScheduleModelsFetch();
  197. // Notifies observers of |optimization_target| that the model has been
  198. // updated.
  199. void NotifyObserversOfNewModel(proto::OptimizationTarget optimization_target,
  200. const ModelInfo& model_info);
  201. // A map of optimization target to the model file containing the model for the
  202. // target.
  203. base::flat_map<proto::OptimizationTarget, std::unique_ptr<ModelInfo>>
  204. optimization_target_model_info_map_;
  205. // The map from optimization targets to feature-provided metadata that have
  206. // been registered with the prediction manager.
  207. base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>>
  208. registered_optimization_targets_and_metadata_;
  209. // The map from optimization target to observers that have been registered to
  210. // receive model updates from the prediction manager.
  211. std::map<proto::OptimizationTarget,
  212. base::ObserverList<OptimizationTargetModelObserver>>
  213. registered_observers_for_optimization_targets_;
  214. // The fetcher that handles making requests to update the models and host
  215. // model features from the remote Optimization Guide Service.
  216. std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher_;
  217. // The downloader that handles making requests to download the prediction
  218. // models. Can be null if model downloading is disabled.
  219. std::unique_ptr<PredictionModelDownloadManager>
  220. prediction_model_download_manager_;
  221. // TODO(crbug/1183507): Remove host model features store and all relevant
  222. // code, and deprecate the proto field too.
  223. // The optimization guide store that contains prediction models and host
  224. // model features from the remote Optimization Guide Service.
  225. base::WeakPtr<OptimizationGuideStore> model_and_features_store_;
  226. // A stored response from a model and host model features fetch used to hold
  227. // models to be stored once host model features are processed and stored.
  228. std::unique_ptr<proto::GetModelsResponse> get_models_response_data_to_store_;
  229. // The URL loader factory used for fetching model and host feature updates
  230. // from the remote Optimization Guide Service.
  231. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
  232. // The logger that plumbs the debug logs to the optimization guide
  233. // internals page. Not owned. Guaranteed to outlive |this|, since the logger
  234. // and |this| are owned by the optimization guide keyed service.
  235. raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
  236. // A reference to the PrefService for this profile. Not owned.
  237. raw_ptr<PrefService> pref_service_ = nullptr;
  238. // The repeating callback that will be used to determine if component updates
  239. // are enabled.
  240. ComponentUpdatesEnabledProvider component_updates_enabled_provider_;
  241. // Time the prediction manager got initialized.
  242. base::TimeTicks init_time_;
  243. // The timer used to schedule fetching prediction models and host model
  244. // features from the remote Optimization Guide Service.
  245. base::OneShotTimer fetch_timer_;
  246. // The clock used to schedule fetching from the remote Optimization Guide
  247. // Service.
  248. raw_ptr<const base::Clock> clock_;
  249. // Whether the |model_and_features_store_| is initialized and ready for use.
  250. bool store_is_ready_ = false;
  251. // Whether host model features have been loaded from the store and are ready
  252. // for use.
  253. bool host_model_features_loaded_ = false;
  254. // Whether the profile for this PredictionManager is off the record.
  255. bool off_the_record_ = false;
  256. // The locale of the application.
  257. std::string application_locale_;
  258. // The path to the directory containing the models.
  259. base::FilePath models_dir_path_;
  260. SEQUENCE_CHECKER(sequence_checker_);
  261. // Used to get |weak_ptr_| to self on the UI thread.
  262. base::WeakPtrFactory<PredictionManager> ui_weak_ptr_factory_{this};
  263. };
  264. } // namespace optimization_guide
  265. #endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MANAGER_H_