
/demos/web/server.py

https://gitlab.com/oytunistrator/openface
#!/usr/bin/env python2
#
# Copyright 2015 Carnegie Mellon University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
fileDir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(fileDir, "..", ".."))

from autobahn.twisted.websocket import WebSocketServerProtocol, \
    WebSocketServerFactory
from twisted.python import log
from twisted.internet import reactor

import argparse
import cv2
import imagehash
import json
from PIL import Image
import numpy as np
import StringIO
import urllib
import base64

from sklearn.decomposition import PCA
from sklearn.grid_search import GridSearchCV
from sklearn.manifold import TSNE
from sklearn.svm import SVC

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import openface
import tempfile
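
# Note: this file targets Python 2 (StringIO, urllib.quote) and an old
# scikit-learn: sklearn.grid_search moved to sklearn.model_selection in
# scikit-learn 0.18 and was removed in 0.20.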
modelDir = os.path.join(fileDir, '..', '..', 'models')
dlibModelDir = os.path.join(modelDir, 'dlib')
openfaceModelDir = os.path.join(modelDir, 'openface')

parser = argparse.ArgumentParser()
parser.add_argument('--dlibFaceMean', type=str, help="Path to dlib's mean face landmarks.",
                    default=os.path.join(dlibModelDir, "mean.csv"))
parser.add_argument('--dlibFacePredictor', type=str, help="Path to dlib's face predictor.",
                    default=os.path.join(dlibModelDir, "shape_predictor_68_face_landmarks.dat"))
parser.add_argument('--dlibRoot', type=str,
                    default=os.path.expanduser(
                        "~/src/dlib-18.16/python_examples"),
                    help="dlib directory with the dlib.so Python library.")
parser.add_argument('--networkModel', type=str, help="Path to Torch network model.",
                    default=os.path.join(openfaceModelDir, 'nn4.v1.t7'))
parser.add_argument('--imgDim', type=int,
                    help="Default image dimension.", default=96)
# 'store_true' instead of type=bool: argparse's type=bool treats any
# non-empty string, including "False", as True.
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--unknown', action='store_true', default=False,
                    help='Try to predict unknown people')
args = parser.parse_args()

sys.path.append(args.dlibRoot)
import dlib  # Must come after args.dlibRoot is appended to sys.path.

from openface.alignment import NaiveDlib  # Depends on dlib.

align = NaiveDlib(args.dlibFaceMean, args.dlibFacePredictor)
net = openface.TorchWrap(args.networkModel, imgDim=args.imgDim, cuda=args.cuda)
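
# `align` and `net` are module-level singletons shared by every client
# connection handled below.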


class Face:

    def __init__(self, rep, identity):
        self.rep = rep
        self.identity = identity

    def __repr__(self):
        return "{{id: {}, rep[0:5]: {}}}".format(
            str(self.identity),
            self.rep[0:5])
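

# The browser client and this server exchange JSON messages over a
# WebSocket. Message types handled in onMessage below:
#   ALL_STATE       -- client restores images/people/training state
#   NULL            -- keep-alive; echoed back
#   FRAME           -- a webcam frame as a base64 JPEG data URL
#   TRAINING        -- toggles training mode; leaving it trains the SVM
#   ADD_PERSON      -- appends a new person name
#   UPDATE_IDENTITY -- relabels a stored image by its hash
#   REMOVE_IMAGE    -- deletes a stored image by its hash
#   REQ_TSNE        -- requests a t-SNE plot of the stored representations
# The server replies with NULL, PROCESSED, NEW_IMAGE, IDENTITIES,
# ANNOTATED, and TSNE_DATA messages.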
class OpenFaceServerProtocol(WebSocketServerProtocol):

    def __init__(self):
        self.images = {}    # phash -> Face
        self.training = True
        self.people = []    # identity index -> person name
        self.svm = None
        if args.unknown:
            self.unknownImgs = np.load("./examples/web/unknown.npy")

    def onConnect(self, request):
        print("Client connecting: {0}".format(request.peer))
        self.training = True

    def onOpen(self):
        print("WebSocket connection open.")

    def onMessage(self, payload, isBinary):
        raw = payload.decode('utf8')
        msg = json.loads(raw)
        print("Received {} message of length {}.".format(
            msg['type'], len(raw)))
        if msg['type'] == "ALL_STATE":
            self.loadState(msg['images'], msg['training'], msg['people'])
        elif msg['type'] == "NULL":
            self.sendMessage('{"type": "NULL"}')
        elif msg['type'] == "FRAME":
            self.processFrame(msg['dataURL'], msg['identity'])
            self.sendMessage('{"type": "PROCESSED"}')
        elif msg['type'] == "TRAINING":
            self.training = msg['val']
            if not self.training:
                self.trainSVM()
        elif msg['type'] == "ADD_PERSON":
            self.people.append(msg['val'].encode('ascii', 'ignore'))
            print(self.people)
        elif msg['type'] == "UPDATE_IDENTITY":
            h = msg['hash'].encode('ascii', 'ignore')
            if h in self.images:
                self.images[h].identity = msg['idx']
                if not self.training:
                    self.trainSVM()
            else:
                print("Image not found.")
        elif msg['type'] == "REMOVE_IMAGE":
            h = msg['hash'].encode('ascii', 'ignore')
            if h in self.images:
                del self.images[h]
                if not self.training:
                    self.trainSVM()
            else:
                print("Image not found.")
        elif msg['type'] == 'REQ_TSNE':
            self.sendTSNE(msg['people'])
        else:
            print("Warning: Unknown message type: {}".format(msg['type']))

    def onClose(self, wasClean, code, reason):
        print("WebSocket connection closed: {0}".format(reason))

    def loadState(self, jsImages, training, jsPeople):
        self.training = training

        for jsImage in jsImages:
            h = jsImage['hash'].encode('ascii', 'ignore')
            self.images[h] = Face(np.array(jsImage['representation']),
                                  jsImage['identity'])

        for jsPerson in jsPeople:
            self.people.append(jsPerson.encode('ascii', 'ignore'))

        if not training:
            self.trainSVM()
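
    # getData assembles the (X, y) training matrix from the stored
    # representations. With --unknown, it pads the unknown class (-1) with
    # precomputed representations so the classes stay roughly balanced.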
    def getData(self):
        X = []
        y = []
        for img in self.images.values():
            X.append(img.rep)
            y.append(img.identity)

        numIdentities = len(set(y + [-1])) - 1
        if numIdentities == 0:
            return None

        if args.unknown:
            numUnknown = y.count(-1)
            numIdentified = len(y) - numUnknown
            numUnknownAdd = (numIdentified / numIdentities) - numUnknown
            if numUnknownAdd > 0:
                print("+ Augmenting with {} unknown images.".format(numUnknownAdd))
                for rep in self.unknownImgs[:numUnknownAdd]:
                    X.append(rep)
                    y.append(-1)

        X = np.vstack(X)
        y = np.array(y)
        return (X, y)
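
    # sendTSNE reduces the face representations to 2-D (PCA to 50
    # components, then t-SNE), renders a labeled scatter plot with
    # matplotlib, and ships it to the client as a base64 PNG data URL.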
    def sendTSNE(self, people):
        d = self.getData()
        if d is None:
            return
        else:
            (X, y) = d

        X_pca = PCA(n_components=50).fit_transform(X)
        tsne = TSNE(n_components=2, init='random', random_state=0)
        X_r = tsne.fit_transform(X_pca)

        yVals = list(np.unique(y))
        colors = cm.rainbow(np.linspace(0, 1, len(yVals)))

        plt.figure()
        for c, i in zip(colors, yVals):
            name = "Unknown" if i == -1 else people[i]
            plt.scatter(X_r[y == i, 0], X_r[y == i, 1], c=c, label=name)

        plt.legend()

        imgdata = StringIO.StringIO()
        plt.savefig(imgdata, format='png')
        imgdata.seek(0)

        content = 'data:image/png;base64,' + \
            urllib.quote(base64.b64encode(imgdata.buf))
        msg = {
            "type": "TSNE_DATA",
            "content": content
        }
        self.sendMessage(json.dumps(msg))
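
    # trainSVM fits a fresh classifier on all labeled images, selecting C,
    # the kernel (linear vs. RBF), and gamma by 5-fold cross-validated
    # grid search.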
    def trainSVM(self):
        print("+ Training SVM on {} labeled images.".format(len(self.images)))
        d = self.getData()
        if d is None:
            self.svm = None
            return
        else:
            (X, y) = d

        # y is a numpy array here, so `set(y + [-1])` would broadcast the
        # addition instead of appending; count the unique labels directly.
        numIdentities = len(np.unique(y))
        if numIdentities <= 1:
            return

        param_grid = [
            {'C': [1, 10, 100, 1000],
             'kernel': ['linear']},
            {'C': [1, 10, 100, 1000],
             'gamma': [0.001, 0.0001],
             'kernel': ['rbf']}
        ]
        self.svm = GridSearchCV(SVC(C=1), param_grid, cv=5).fit(X, y)
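
    # processFrame decodes one webcam frame from a JPEG data URL, finds the
    # largest face, aligns it, and either stores its representation
    # (training mode) or classifies it and sends back an annotated frame.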
    def processFrame(self, dataURL, identity):
        head = "data:image/jpeg;base64,"
        assert dataURL.startswith(head)
        imgdata = base64.b64decode(dataURL[len(head):])
        imgF = StringIO.StringIO()
        imgF.write(imgdata)
        imgF.seek(0)
        img = Image.open(imgF)

        # Mirror the frame and swap the R and B channels. The buffer size
        # below assumes the web client sends 400x300 frames.
        buf = np.fliplr(np.asarray(img))
        rgbFrame = np.zeros((300, 400, 3), dtype=np.uint8)
        rgbFrame[:, :, 0] = buf[:, :, 2]
        rgbFrame[:, :, 1] = buf[:, :, 1]
        rgbFrame[:, :, 2] = buf[:, :, 0]

        if not self.training:
            annotatedFrame = np.copy(buf)

        identities = []
        # Only the largest face is processed per frame. To handle all faces,
        # use align.getAllFaceBoundingBoxes(rgbFrame) instead.
        bb = align.getLargestFaceBoundingBox(rgbFrame)
        bbs = [bb] if bb is not None else []
        for bb in bbs:
            alignedFace = align.alignImg("affine", 96, rgbFrame, bb)
            if alignedFace is None:
                continue

            # A perceptual hash of the aligned face serves as a stable key
            # that deduplicates near-identical frames.
            phash = str(imagehash.phash(Image.fromarray(alignedFace)))
            if phash in self.images:
                identity = self.images[phash].identity
            else:
                rep = net.forwardImage(alignedFace)
                if self.training:
                    self.images[phash] = Face(rep, identity)
                    # TODO: Transferring as a string is suboptimal.
                    # content = [str(x) for x in cv2.resize(alignedFace, (0,0),
                    # fx=0.5, fy=0.5).flatten()]
                    content = [str(x) for x in alignedFace.flatten()]
                    msg = {
                        "type": "NEW_IMAGE",
                        "hash": phash,
                        "content": content,
                        "identity": identity,
                        "representation": rep.tolist()
                    }
                    self.sendMessage(json.dumps(msg))
                else:
                    if len(self.people) == 0:
                        identity = -1
                    elif len(self.people) == 1:
                        identity = 0
                    elif self.svm:
                        identity = self.svm.predict(rep)[0]
                    else:
                        print("No SVM trained yet; marking face as unknown.")
                        identity = -1
                    if identity not in identities:
                        identities.append(identity)

            if not self.training:
                bl = (bb.left(), bb.bottom())
                tr = (bb.right(), bb.top())
                cv2.rectangle(annotatedFrame, bl, tr, color=(153, 255, 204),
                              thickness=3)
                if identity == -1:
                    if len(self.people) == 1:
                        name = self.people[0]
                    else:
                        name = "Unknown"
                else:
                    name = self.people[identity]
                cv2.putText(annotatedFrame, name, (bb.left(), bb.top() - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.75,
                            color=(152, 255, 204), thickness=2)

        if not self.training:
            msg = {
                "type": "IDENTITIES",
                "identities": identities
            }
            self.sendMessage(json.dumps(msg))

            plt.figure()
            plt.imshow(annotatedFrame)
            plt.xticks([])
            plt.yticks([])

            imgdata = StringIO.StringIO()
            plt.savefig(imgdata, format='png')
            imgdata.seek(0)
            content = 'data:image/png;base64,' + \
                urllib.quote(base64.b64encode(imgdata.buf))
            msg = {
                "type": "ANNOTATED",
                "content": content
            }
            self.sendMessage(json.dumps(msg))


if __name__ == '__main__':
    log.startLogging(sys.stdout)

    factory = WebSocketServerFactory("ws://localhost:9000", debug=False)
    factory.protocol = OpenFaceServerProtocol

    reactor.listenTCP(9000, factory)
    reactor.run()
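
# Usage sketch (assuming the model files above are in place): run
# `python2 server.py` and connect a WebSocket client to ws://localhost:9000.
# A minimal FRAME message, using the keys read in onMessage, looks like:
#
#   {"type": "FRAME",
#    "dataURL": "data:image/jpeg;base64,<base64-encoded 400x300 JPEG>",
#    "identity": 0}
#
# In training mode the server replies with NEW_IMAGE and PROCESSED; in
# recognition mode with IDENTITIES, ANNOTATED, and PROCESSED.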