
/examples/applications/plot_stock_market.py

http://github.com/scikit-learn/scikit-learn
"""
=======================================
Visualizing the stock market structure
=======================================

This example employs several unsupervised learning techniques to extract
the stock market structure from variations in historical quotes.

The quantity that we use is the daily variation in quote price: quotes
that are linked tend to co-fluctuate during a day.
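On real quotes this quantity is simply the per-day difference between closing and opening price. A minimal sketch on made-up numbers (the two tiny arrays below are purely illustrative):

```python
import numpy as np

# Made-up prices: one row per symbol, one column per trading day
close = np.array([[10.0, 10.5, 10.2],
                  [20.0, 19.5, 20.4]])
open_ = np.array([[ 9.8, 10.6, 10.1],
                  [20.2, 19.4, 20.0]])

# The daily variation is the quantity the rest of the example works with
variation = close - open_
print(variation)
```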
.. _stock_market:
Learning a graph structure
--------------------------

We use sparse inverse covariance estimation to find which quotes are
correlated conditionally on the others. Specifically, sparse inverse
covariance gives us a graph, that is, a list of connections. For each
symbol, the symbols that it is connected to are those useful to explain
its fluctuations.
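A minimal, self-contained sketch of this step (synthetic data standing in for real quotes): :class:`GraphicalLassoCV` fits the sparse inverse covariance, and its ``precision_`` attribute is the matrix whose non-zero off-diagonal entries define the graph:

```python
import numpy as np
from sklearn.covariance import GraphicalLassoCV

rng = np.random.RandomState(0)
# Synthetic "daily variations": 60 days for 5 fake symbols
X = rng.randn(60, 5)
X[:, 1] += 0.8 * X[:, 0]  # make two series co-fluctuate

model = GraphicalLassoCV().fit(X)
# A non-zero off-diagonal entry of the precision (inverse covariance)
# matrix links two symbols conditionally on all the others
print(model.precision_.shape)
```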
Clustering
----------

We use clustering to group together quotes that behave similarly. Here,
amongst the :ref:`various clustering techniques <clustering>` available
in scikit-learn, we use :ref:`affinity_propagation` as it does
not enforce equal-size clusters, and it can automatically choose the
number of clusters from the data.

Note that this gives us a different indication than the graph, as the
graph reflects conditional relations between variables, while the
clustering reflects marginal properties: variables clustered together can
be considered as having a similar impact at the level of the full stock
market.
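A minimal sketch of affinity propagation on a hand-built similarity matrix (the numbers are illustrative): the two tight pairs end up in separate clusters, and the number of clusters is inferred from the data rather than fixed in advance:

```python
import numpy as np
from sklearn.cluster import affinity_propagation

# Toy similarity matrix for 4 items: items 0-1 and 2-3 form tight pairs
S = np.array([[0., 5., 1., 1.],
              [5., 0., 1., 1.],
              [1., 1., 0., 5.],
              [1., 1., 5., 0.]])
centers, labels = affinity_propagation(S, random_state=0)
# Each tight pair shares a label; the pairs get different labels
print(labels)
```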
Embedding in 2D space
---------------------

For visualization purposes, we need to lay out the different symbols on a
2D canvas. For this we use :ref:`manifold` techniques to retrieve a 2D
embedding.
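A minimal sketch of the embedding step on synthetic data (the array shapes, not the values, are what matter): :class:`LocallyLinearEmbedding` maps each high-dimensional point to a pair of 2D coordinates:

```python
import numpy as np
from sklearn.manifold import LocallyLinearEmbedding

rng = np.random.RandomState(0)
X = rng.randn(30, 10)  # 30 "symbols", each described by 10 features

# Dense eigen_solver for reproducibility, as in the script below
lle = LocallyLinearEmbedding(n_components=2, eigen_solver='dense',
                             n_neighbors=6)
coords = lle.fit_transform(X)
print(coords.shape)
```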
Visualization
-------------

The output of the 3 models is combined in a 2D graph where nodes
represent the stocks and edges the links between them:

- cluster labels are used to define the color of the nodes
- the sparse covariance model is used to display the strength of the edges
- the 2D embedding is used to position the nodes in the plane

This example has a fair amount of visualization-related code, as
visualization is crucial here to display the graph. One of the challenges
is to position the labels to minimize overlap. For this we use a
heuristic based on the direction of the nearest neighbor along each
axis.
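The heuristic can be sketched in isolation on a toy layout (three made-up node positions): for a given node, look at the nearest neighbor along each axis and place the label on the opposite side:

```python
import numpy as np

# Toy 2D node positions, one column per node (mimicking `embedding`)
embedding = np.array([[0.0, 1.0, 0.1],
                      [0.0, 0.2, 1.0]])

index = 0
x, y = embedding[:, index]
dx = x - embedding[0]
dx[index] = 1  # ignore the node itself
dy = y - embedding[1]
dy[index] = 1
# Direction of the nearest neighbor along each axis decides on
# which side of the node the label goes
this_dx = dx[np.argmin(np.abs(dy))]
this_dy = dy[np.argmin(np.abs(dx))]
ha = 'left' if this_dx > 0 else 'right'
va = 'bottom' if this_dy > 0 else 'top'
print(ha, va)
```

Here both nearest neighbors lie to the upper right of node 0, so the label is pushed down and to the left.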
"""
# Author: Gael Varoquaux gael.varoquaux@normalesup.org
# License: BSD 3 clause
import sys

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import pandas as pd

from sklearn import cluster, covariance, manifold

print(__doc__)
# #############################################################################
# Retrieve the data from Internet
# The data is from 2003 - 2008. This period is reasonably calm (not too long
# ago, so that we get high-tech firms, and before the 2008 crash). This kind
# of historical data can be obtained from APIs like the quandl.com and
# alphavantage.co ones.
symbol_dict = {
    'TOT': 'Total',
    'XOM': 'Exxon',
    'CVX': 'Chevron',
    'COP': 'ConocoPhillips',
    'VLO': 'Valero Energy',
    'MSFT': 'Microsoft',
    'IBM': 'IBM',
    'TWX': 'Time Warner',
    'CMCSA': 'Comcast',
    'CVC': 'Cablevision',
    'YHOO': 'Yahoo',
    'DELL': 'Dell',
    'HPQ': 'HP',
    'AMZN': 'Amazon',
    'TM': 'Toyota',
    'CAJ': 'Canon',
    'SNE': 'Sony',
    'F': 'Ford',
    'HMC': 'Honda',
    'NAV': 'Navistar',
    'NOC': 'Northrop Grumman',
    'BA': 'Boeing',
    'KO': 'Coca Cola',
    'MMM': '3M',
    'MCD': 'McDonald\'s',
    'PEP': 'Pepsi',
    'K': 'Kellogg',
    'UN': 'Unilever',
    'MAR': 'Marriott',
    'PG': 'Procter & Gamble',
    'CL': 'Colgate-Palmolive',
    'GE': 'General Electric',
    'WFC': 'Wells Fargo',
    'JPM': 'JPMorgan Chase',
    'AIG': 'AIG',
    'AXP': 'American Express',
    'BAC': 'Bank of America',
    'GS': 'Goldman Sachs',
    'AAPL': 'Apple',
    'SAP': 'SAP',
    'CSCO': 'Cisco',
    'TXN': 'Texas Instruments',
    'XRX': 'Xerox',
    'WMT': 'Wal-Mart',
    'HD': 'Home Depot',
    'GSK': 'GlaxoSmithKline',
    'PFE': 'Pfizer',
    'SNY': 'Sanofi-Aventis',
    'NVS': 'Novartis',
    'KMB': 'Kimberly-Clark',
    'R': 'Ryder',
    'GD': 'General Dynamics',
    'RTN': 'Raytheon',
    'CVS': 'CVS',
    'CAT': 'Caterpillar',
    'DD': 'DuPont de Nemours'}
symbols, names = np.array(sorted(symbol_dict.items())).T

url = ('https://raw.githubusercontent.com/scikit-learn/examples-data/'
       'master/financial-data/{}.csv')

quotes = []
for symbol in symbols:
    print('Fetching quote history for %r' % symbol, file=sys.stderr)
    quotes.append(pd.read_csv(url.format(symbol)))
close_prices = np.vstack([q['close'] for q in quotes])
open_prices = np.vstack([q['open'] for q in quotes])

# The daily variations of the quotes are what carry most information
variation = close_prices - open_prices
# #############################################################################
# Learn a graphical structure from the correlations
edge_model = covariance.GraphicalLassoCV()

# Standardize the time series: using correlations rather than covariance
# is more efficient for structure recovery
X = variation.copy().T
X /= X.std(axis=0)
edge_model.fit(X)
# #############################################################################
# Cluster using affinity propagation
# (random_state fixed so the clustering is reproducible across runs)
_, labels = cluster.affinity_propagation(edge_model.covariance_,
                                         random_state=0)
n_labels = labels.max()

for i in range(n_labels + 1):
    print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))
# #############################################################################
# Find a low-dimension embedding for visualization: find the best position of
# the nodes (the stocks) on a 2D plane
# We use a dense eigen_solver to achieve reproducibility (arpack is
# initiated with random vectors that we don't control). In addition, we
# use a large number of neighbors to capture the large-scale structure.
node_position_model = manifold.LocallyLinearEmbedding(
    n_components=2, eigen_solver='dense', n_neighbors=6)

embedding = node_position_model.fit_transform(X.T).T
# #############################################################################
# Visualization
plt.figure(1, facecolor='w', figsize=(10, 8))
plt.clf()
ax = plt.axes([0., 0., 1., 1.])
plt.axis('off')

# Display a graph of the partial correlations
partial_correlations = edge_model.precision_.copy()
d = 1 / np.sqrt(np.diag(partial_correlations))
partial_correlations *= d
partial_correlations *= d[:, np.newaxis]
non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

# Plot the nodes using the coordinates of our embedding
plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,
            cmap=plt.cm.nipy_spectral)

# Plot the edges: a sequence of (*line0*, *line1*, *line2*), where::
#     linen = (x0, y0), (x1, y1), ... (xm, ym)
start_idx, end_idx = np.where(non_zero)
segments = [[embedding[:, start], embedding[:, stop]]
            for start, stop in zip(start_idx, end_idx)]
values = np.abs(partial_correlations[non_zero])
lc = LineCollection(segments,
                    zorder=0, cmap=plt.cm.hot_r,
                    norm=plt.Normalize(0, .7 * values.max()))
lc.set_array(values)
lc.set_linewidths(15 * values)
ax.add_collection(lc)
# Add a label to each node. The challenge here is that we want to
# position the labels to avoid overlap with other labels
for index, (name, label, (x, y)) in enumerate(
        zip(names, labels, embedding.T)):

    dx = x - embedding[0]
    dx[index] = 1
    dy = y - embedding[1]
    dy[index] = 1
    this_dx = dx[np.argmin(np.abs(dy))]
    this_dy = dy[np.argmin(np.abs(dx))]
    if this_dx > 0:
        horizontalalignment = 'left'
        x = x + .002
    else:
        horizontalalignment = 'right'
        x = x - .002
    if this_dy > 0:
        verticalalignment = 'bottom'
        y = y + .002
    else:
        verticalalignment = 'top'
        y = y - .002
    plt.text(x, y, name, size=10,
             horizontalalignment=horizontalalignment,
             verticalalignment=verticalalignment,
             bbox=dict(facecolor='w',
                       edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
                       alpha=.6))
plt.xlim(embedding[0].min() - .15 * embedding[0].ptp(),
         embedding[0].max() + .10 * embedding[0].ptp())
plt.ylim(embedding[1].min() - .03 * embedding[1].ptp(),
         embedding[1].max() + .03 * embedding[1].ptp())

plt.show()