Source code for owtf.proxy.proxy

"""
owtf.proxy.proxy
~~~~~~~~~~~~~~~~

Inbound Proxy Module developed by Bharadwaj Machiraju (blog.tunnelshade.in) as a part of Google Summer of Code 2013.
"""
import datetime
import logging
import socket
import ssl
import threading
import time
import select
import os

import pycurl
import tornado.curl_httpclient
import tornado.escape
import tornado.gen
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.options
import tornado.template
import tornado.web
import tornado.websocket

from owtf.proxy.cache_handler import CacheHandler
from owtf.proxy.socket_wrapper import starttls
from owtf.proxy.interceptor_manager import InterceptorManager
from owtf.proxy.live_interceptor import LiveInterceptor
from owtf.utils.strings import utf8, to_str

# Module-level logger for proxy diagnostics
logger = logging.getLogger(__name__)

# Request/response logging configuration
REQUEST_LOG_FILE = "/tmp/owtf/request_response.log"
# Master switch read by log_request(); toggled at runtime by disable_request_logging() / enable_request_logging()
ENABLE_REQUEST_LOGGING = True

# Maximum number of entries log_request() writes per 60-second window; changed by set_logging_rate_limit()
MAX_LOG_ENTRIES_PER_MINUTE = 100

# Live interception timeout configuration
LIVE_INTERCEPTION_TIMEOUT = 30  # Seconds to wait for a live interception decision before auto-forwarding
LIVE_INTERCEPTION_DELAY = 0.1  # Seconds between two polls for a live interception decision

# Dedicated logger for intercepted traffic
request_logger = logging.getLogger("owtf_requests")
request_logger.setLevel(logging.INFO)

# Rate-limiting state maintained by log_request(): entries written in the current window and the window start time
log_entries_this_minute = 0
last_minute_reset = time.time()


def disable_request_logging():
    """Switch request/response logging off, e.g. to prevent disk space issues.

    Clears the module-level ``ENABLE_REQUEST_LOGGING`` flag and records the change as a warning.

    :return: None
    :rtype: None
    """
    global ENABLE_REQUEST_LOGGING

    ENABLE_REQUEST_LOGGING = False
    logger.warning("Request logging has been disabled")
def enable_request_logging():
    """Switch request/response logging back on.

    Sets the module-level ``ENABLE_REQUEST_LOGGING`` flag and records the change at info level.

    :return: None
    :rtype: None
    """
    global ENABLE_REQUEST_LOGGING

    ENABLE_REQUEST_LOGGING = True
    logger.info("Request logging has been enabled")
def set_logging_rate_limit(entries_per_minute):
    """Set the maximum number of log entries written per minute.

    The value is validated before it is stored: the limit is compared against an integer counter on every
    logged request, so a non-numeric value would make that comparison raise on each request.

    :param entries_per_minute: New limit; ``0`` suppresses all entries
    :type entries_per_minute: `int` or `float`
    :return: None
    :rtype: None
    :raises TypeError: If ``entries_per_minute`` is not a number
    :raises ValueError: If ``entries_per_minute`` is negative or NaN
    """
    global MAX_LOG_ENTRIES_PER_MINUTE

    if not isinstance(entries_per_minute, (int, float)):
        raise TypeError(f"entries_per_minute must be a number, got {type(entries_per_minute).__name__}")
    # "not >=" instead of "<" so that NaN is rejected as well
    if not entries_per_minute >= 0:
        raise ValueError(f"entries_per_minute must be >= 0, got {entries_per_minute}")

    MAX_LOG_ENTRIES_PER_MINUTE = entries_per_minute
    logger.info(f"Logging rate limit set to {entries_per_minute} entries per minute")
[docs] def cleanup_large_log_file(): """Clean up log file if it's too large""" try: if os.path.exists(REQUEST_LOG_FILE): file_size = os.path.getsize(REQUEST_LOG_FILE) if file_size > 50 * 1024 * 1024: # If larger than 50MB logger.warning(f"Log file is {file_size / (1024*1024):.1f}MB, truncating to prevent disk space issues") # Truncate the file to keep only the last 1MB with open(REQUEST_LOG_FILE, "r+b") as f: f.seek(-1024 * 1024, 2) # Go to 1MB from end f.truncate() f.write(b"\n--- LOG FILE TRUNCATED DUE TO SIZE ---\n") except Exception as e: logger.error(f"Error cleaning up log file: {e}")
# Create log directory if it doesn't exist
try:
    os.makedirs(os.path.dirname(REQUEST_LOG_FILE), exist_ok=True)
except Exception as e:
    logger.error(f"Error creating log directory: {e}")

# Clean up large log file on startup
cleanup_large_log_file()

# Create file handler for request logging with rotation
if not request_logger.handlers:
    from logging.handlers import RotatingFileHandler

    try:
        file_handler = RotatingFileHandler(
            REQUEST_LOG_FILE,
            maxBytes=10 * 1024 * 1024,  # 10MB max file size
            backupCount=5,  # Keep 5 backup files
        )
    except OSError as e:
        # Directory creation above is best effort, so opening the file may still fail; an unusable
        # log file must not prevent the proxy module from being imported.
        logger.error(f"Error creating request log handler: {e}")
    else:
        file_handler.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s - %(message)s")
        file_handler.setFormatter(formatter)
        request_logger.addHandler(file_handler)
        request_logger.propagate = False
def log_request(request, method, url, headers=None, body=None, is_https=False, is_response=False):
    """Append one intercepted request or response to the request log file.

    Nothing is written while logging is disabled or once the per-minute rate limit is reached. Large header
    sets and bodies are abbreviated. Errors while building or writing the entry are logged, never raised.

    :param request: Not used by this function
    :param method: HTTP method, or a status label for responses
    :type method: `str`
    :param url: URL the entry refers to
    :type url: `str`
    :param headers: Headers; anything ``dict()`` accepts
    :param body: Body as `bytes` or any object convertible with ``str()``
    :param is_https: Whether the traffic was HTTPS
    :type is_https: `bool`
    :param is_response: Whether the entry describes a response rather than a request
    :type is_response: `bool`
    :return: None
    :rtype: None
    """
    global log_entries_this_minute, last_minute_reset

    if not ENABLE_REQUEST_LOGGING:
        return

    # Start a new rate-limiting window once the current one is a minute old
    now = time.time()
    if now - last_minute_reset >= 60:
        log_entries_this_minute = 0
        last_minute_reset = now

    # Entries beyond the limit are silently skipped
    if log_entries_this_minute >= MAX_LOG_ENTRIES_PER_MINUTE:
        return
    log_entries_this_minute += 1

    try:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        protocol = "HTTPS" if is_https else "HTTP"
        direction = "RESPONSE" if is_response else "REQUEST"
        pieces = [f"[{timestamp}] {protocol} {direction} {method} {url}\n"]

        if headers:
            header_dict = dict(headers)
            # Only a count is logged when the rendered headers exceed 2000 characters
            headers_too_long = len(str(header_dict)) > 2000
            pieces.append(
                f"Headers: <{len(header_dict)} headers, truncated>\n"
                if headers_too_long
                else f"Headers: {header_dict}\n"
            )

        if body:
            # At most the first 500 bytes/characters of a body are logged
            if isinstance(body, bytes):
                body_size = len(body)
                pieces.append(
                    f"Body: <{body_size} bytes, first 500: {body[:500]}>\n"
                    if body_size > 500
                    else f"Body: {body}\n"
                )
            else:
                body_str = str(body)
                pieces.append(
                    f"Body: <{len(body_str)} chars, first 500: {body_str[:500]}>\n"
                    if len(body_str) > 500
                    else f"Body: {body_str}\n"
                )

        pieces.append("-" * 80 + "\n")

        with open(REQUEST_LOG_FILE, "a", encoding="utf-8") as f:
            f.write("".join(pieces))
    except Exception as e:
        logger.error(f"Error logging request/response: {e}")
def log_response(status_code, url, headers=None, body=None, is_https=False):
    """Log HTTP response details.

    Delegates to ``log_request`` with the status rendered as ``HTTP/<status_code>`` in place of the method
    and the entry flagged as a response.

    :param status_code: Response status code
    :param url: URL the response belongs to
    :type url: `str`
    :param headers: Response headers
    :param body: Response body
    :param is_https: Whether the traffic was HTTPS
    :type is_https: `bool`
    :return: None
    :rtype: None
    """
    status_label = f"HTTP/{status_code}"
    log_request(None, status_label, url, headers=headers, body=body, is_https=is_https, is_response=True)
def prepare_curl_callback(curl):
    """Set the proxy type of a pycurl handle to SOCKS5.

    Passed as ``prepare_curl_callback`` of the outgoing ``HTTPRequest`` when the application's outbound proxy
    type is ``"socks"``.

    :param curl: pycurl handle to configure
    :return: None
    :rtype: None
    """
    curl.setopt(pycurl.PROXYTYPE, pycurl.PROXYTYPE_SOCKS5)
class ProxyHandler(tornado.web.RequestHandler):
    """This RequestHandler processes all the requests that the application received.

    * :meth:`get` serves every supported method except CONNECT (the other verb methods are aliases of it)
    * :meth:`connect` serves CONNECT requests by intercepting the TLS session
    * Requests carrying ``Upgrade: websocket`` are handed to ``CustomWebSocketHandler`` by :meth:`__new__`
    """

    SUPPORTED_METHODS = ["GET", "POST", "CONNECT", "HEAD", "PUT", "DELETE", "OPTIONS", "TRACE"]
    server = None
    # Header names stripped from outgoing requests / incoming responses. Both are iterated and used with
    # ``in`` by this class, so they must be replaced by collections before requests are handled.
    restricted_request_headers = None
    restricted_response_headers = None

    def __init__(self, *args, **kwargs):
        """Initialize the proxy handler with interceptor support.

        An ``InterceptorManager`` and a ``LiveInterceptor`` are attached to the application the first time a
        handler is created, so every handler of that application shares them.
        """
        super().__init__(*args, **kwargs)

        # Initialize interceptor manager if not already done
        if not hasattr(self.application, "interceptor_manager"):
            self.application.interceptor_manager = InterceptorManager()
            logger.info("Initialized interceptor manager for proxy handler")

        # Initialize live interceptor if not already done
        if not hasattr(self.application, "live_interceptor"):
            self.application.live_interceptor = LiveInterceptor()
            logger.info("Initialized live interceptor for proxy handler")

    def __new__(cls, application, request, **kwargs):
        """
        .. note::

            http://stackoverflow.com/questions/3209233/how-to-replace-an-instance-in-init-with-a-different-object

        Based on upgrade header, websocket request handler must be used
        """
        try:
            if request.headers["Upgrade"].lower() == "websocket":
                return CustomWebSocketHandler(application, request, **kwargs)
        except KeyError:
            # No Upgrade header: a regular proxy request
            pass
        return tornado.web.RequestHandler.__new__(cls)

    def set_default_headers(self):
        """Automatically called by Tornado, and is used to remove "Server" header set by tornado

        :return: None
        :rtype: None
        """
        # Public API; unlike deleting from the private header mapping it does not fail if the header is absent
        self.clear_header("Server")

    def set_status(self, status_code, reason=None):
        """Sets the status code for our response. Overriding is done so as to handle unknown response codes
        gracefully.

        :param status_code: status code to set
        :type status_code: `int`
        :param reason: Status code reason
        :type reason: `str`
        :return: None
        :rtype: None
        """
        self._status_code = status_code
        if reason is not None:
            self._reason = tornado.escape.native_str(reason)
        else:
            # Codes missing from Tornado's table get a fixed fallback reason
            self._reason = tornado.httputil.responses.get(status_code) or tornado.escape.native_str(
                "Server Not Found"
            )

    def finish_response(self, response):
        """Copy an upstream (or cached) response onto this handler, log it and finish the request.

        Response interceptors are applied first. The body itself is not written here; it has already been
        written by the caller or by :meth:`handle_data_chunk`.

        :param response: Response whose status code and headers are relayed to the client
        :return: None
        :rtype: None
        """
        # Apply response interceptors
        try:
            if hasattr(self.application, "interceptor_manager"):
                response = self.application.interceptor_manager.intercept_response(response)
                logger.debug("Applied response interceptors")
        except Exception as e:
            logger.error(f"Error applying response interceptors: {e}")

        self.set_status(response.code)
        for header, value in response.headers.get_all():
            if header == "Set-Cookie":
                # Several cookies may be set, so they are added rather than overwritten
                self.add_header(header, value)
            elif header not in self.restricted_response_headers:
                self.set_header(header, value)

        # Log the response with body
        response_body = getattr(self.request, "response_buffer", "") or getattr(response, "body", "") or ""
        log_response(
            response.code,
            getattr(self.request, "url", ""),
            response.headers,
            response_body,
            getattr(self.request, "protocol", "http") == "https",
        )
        self.finish()

    def handle_data_chunk(self, data):
        """Callback when a small chunk is received.

        The chunk is written to the client and appended to ``request.response_buffer``. Chunks are ignored
        when the request has no ``response_buffer`` attribute.

        :param data: Data to write
        :type data: `str`
        :return: None
        :rtype: None
        """
        if not hasattr(self.request, "response_buffer"):
            return
        if data:
            self.write(data)
        self.request.response_buffer += to_str(data)

    @tornado.gen.coroutine
    def get(self):
        """Handle all requests except the connect request.

        Flow: build the absolute URL, log the request, apply request interceptors, optionally wait for a live
        interception decision, then answer from the cache or fetch from upstream (with retries), relay the
        response and cache it.

        :return: None
        :rtype: None
        """
        # The flow starts here
        if not hasattr(self.request, "local_timestamp"):
            self.request.local_timestamp = datetime.datetime.now()
        if not hasattr(self.request, "response_buffer"):
            self.request.response_buffer = ""

        # The requests that come through ssl streams are relative requests, so transparent proxying is required.
        # The following snippet decides the url that should be passed to the async client
        if (
            hasattr(self.request, "uri")
            and self.request.uri
            and hasattr(self.request, "protocol")
            and self.request.uri.startswith(self.request.protocol, 0)
        ):
            # Normal Proxy Request.
            if not hasattr(self.request, "url"):
                self.request.url = self.request.uri
        else:
            # Transparent Proxy Request.
            if not hasattr(self.request, "url"):
                self.request.url = "{!s}://{!s}".format(
                    getattr(self.request, "protocol", "http"), getattr(self.request, "host", "")
                )
                if hasattr(self.request, "uri") and self.request.uri and self.request.uri != "/":
                    # Add uri only if needed.
                    self.request.url += self.request.uri

        # Log the intercepted request
        is_https = getattr(self.request, "protocol", "http") == "https"
        log_request(
            self.request,
            getattr(self.request, "method", "GET"),
            getattr(self.request, "url", ""),
            getattr(self.request, "headers", {}),
            getattr(self.request, "body", ""),
            is_https,
        )

        # Apply request interceptors
        try:
            if hasattr(self.application, "interceptor_manager"):
                self.request = self.application.interceptor_manager.intercept_request(self.request)
                logger.debug("Applied request interceptors")
        except Exception as e:
            logger.error(f"Error applying request interceptors: {e}")

        # Check for live interception
        try:
            if hasattr(self.application, "live_interceptor"):
                live_interceptor = self.application.live_interceptor
                # Check if this request should be intercepted
                method = getattr(self.request, "method", "GET")
                url = getattr(self.request, "url", "")
                protocol = getattr(self.request, "protocol", "http")
                headers = dict(self.request.headers)
                body = getattr(self.request, "body", "") or ""

                request_id, should_wait = live_interceptor.intercept_request(method, url, headers, body, protocol)

                if should_wait:
                    # Store the request ID for later decision
                    self.request.intercept_id = request_id
                    logger.info(f"Request intercepted for live modification: {request_id}")

                    # Wait for user decision (with timeout)
                    start_time = time.time()
                    while time.time() - start_time < LIVE_INTERCEPTION_TIMEOUT:
                        decision = live_interceptor.get_decision(request_id)
                        if decision:
                            if decision.value == "drop":
                                logger.info(f"Request {request_id} dropped by user")
                                return  # Drop the request
                            elif decision.value == "modify":
                                # Apply modifications
                                req = live_interceptor.pending_requests.get(request_id)
                                if req and req.modified_headers:
                                    for key, value in req.modified_headers.items():
                                        self.request.headers[key] = value
                                if req and req.modified_body is not None:
                                    self.request.body = req.modified_body
                                logger.info(f"Request {request_id} modified by user")

                            # Clean up
                            live_interceptor.cleanup_request(request_id)
                            break

                        # Yield to the IOLoop while polling: a blocking time.sleep() here would stall every
                        # other connection served by this loop for the whole waiting period.
                        yield tornado.gen.sleep(LIVE_INTERCEPTION_DELAY)
                    else:
                        # Timeout - auto-forward
                        logger.info(f"Request {request_id} timed out, auto-forwarding")
                        live_interceptor.cleanup_request(request_id)
        except Exception as e:
            logger.error(f"Error in live interception: {e}")

        # This block here checks for already cached response and if present returns one
        self.cache_handler = CacheHandler(
            self.application.cache_dir,
            self.request,
            self.application.cookie_regex,
            self.application.cookie_blacklist,
        )
        # Fix for tornado.gen.Task compatibility
        try:
            # For newer Tornado versions, use the callback directly
            self.cache_handler.calculate_hash()
        except TypeError:
            # For older Tornado versions, use Task
            yield tornado.gen.Task(self.cache_handler.calculate_hash)
        self.cached_response = self.cache_handler.load()

        if self.cached_response:
            if self.cached_response.body:
                self.write(self.cached_response.body)
            self.finish_response(self.cached_response)
        else:
            # Request header cleaning
            for header in self.restricted_request_headers:
                try:
                    del self.request.headers[header]
                except KeyError:
                    # Header not present in this request
                    continue

            # HTTP auth if exists
            http_auth_username = None
            http_auth_password = None
            http_auth_mode = None
            if self.application.http_auth:
                host = self.request.host
                # If default ports are not provided, they are added
                if ":" not in self.request.host:
                    default_ports = {"http": "80", "https": "443"}
                    if self.request.protocol in default_ports:
                        host = "{!s}:{!s}".format(self.request.host, default_ports[self.request.protocol])
                # Check if auth is provided for that host
                try:
                    index = self.application.http_auth_hosts.index(host)
                    http_auth_username = self.application.http_auth_usernames[index]
                    http_auth_password = self.application.http_auth_passwords[index]
                    http_auth_mode = self.application.http_auth_modes[index]
                except ValueError:
                    pass

            # pycurl is needed for curl client
            async_client = tornado.curl_httpclient.CurlAsyncHTTPClient()
            success_response = False  # is used to check the response in the botnet mode

            while not success_response:
                # httprequest object is created and then passed to async client with a callback
                callback = None
                if self.application.outbound_proxy_type == "socks":
                    callback = prepare_curl_callback  # socks callback function.
                body = self.request.body or None
                request = tornado.httpclient.HTTPRequest(
                    url=self.request.url,
                    method=self.request.method,
                    body=body,
                    headers=self.request.headers,
                    auth_username=http_auth_username,
                    auth_password=http_auth_password,
                    auth_mode=http_auth_mode,
                    follow_redirects=False,
                    use_gzip=True,
                    streaming_callback=self.handle_data_chunk,
                    header_callback=None,
                    proxy_host=self.application.outbound_ip,
                    proxy_port=self.application.outbound_port,
                    proxy_username=self.application.outbound_username,
                    proxy_password=self.application.outbound_password,
                    allow_nonstandard_methods=True,
                    prepare_curl_callback=callback,
                    validate_cert=False,
                )
                try:
                    # Fix for tornado.gen.Task compatibility
                    try:
                        response = yield async_client.fetch(request)
                    except AttributeError:
                        # For older Tornado versions, use Task
                        response = yield tornado.gen.Task(async_client.fetch, request)
                except Exception:
                    response = None

                # Request retries
                for _ in range(0, 3):
                    if response is None or response.code in [408, 599]:
                        self.request.response_buffer = ""
                        try:
                            response = yield async_client.fetch(request)
                        except AttributeError:
                            # For older Tornado versions, use Task
                            response = yield tornado.gen.Task(async_client.fetch, request)
                    else:
                        success_response = True
                        break
                else:
                    # The response of the last retry has not been examined yet. Without this check a good
                    # response would be discarded and the request sent (and its body streamed) once more.
                    if response is not None and response.code not in [408, 599]:
                        success_response = True

            self.finish_response(response)
            # Cache the response after finishing the response, so caching time is not included in response time
            self.cache_handler.dump(response)

    # Every other supported method except CONNECT is processed exactly like GET
    head = get
    post = get
    put = get
    delete = get
    options = get
    trace = get

    @staticmethod
    def _mitm_parse_headers(lines):
        """Collect ``Name: value`` pairs from raw header lines, stopping at the blank line that ends them.

        :param lines: Raw lines following the start line
        :type lines: `list` of `bytes`
        :return: Header names mapped to values (a repeated name keeps its last value)
        :rtype: `dict`
        """
        headers = {}
        for line in lines:
            if not line:
                break  # End of the header section; what follows is body data
            key, separator, value = line.decode("utf-8", errors="ignore").partition(": ")
            if separator:
                headers[key] = value
        return headers

    @staticmethod
    def _mitm_log_http(host, data_buffer, direction):
        """Log a decrypted chunk if it starts with an HTTP request line or status line.

        Simple approach: only the first line and the headers of the chunk are inspected, bodies are never
        logged. Parsing problems are reported at debug level only.

        :param host: Upstream host, used to build the logged URL
        :type host: `str`
        :param data_buffer: Decrypted bytes read from one side of the tunnel
        :type data_buffer: `bytes`
        :param direction: Label of the side the data came from, used in debug output
        :type direction: `str`
        :return: None
        :rtype: None
        """
        try:
            lines = data_buffer.split(b"\r\n")
            if not lines or not lines[0]:
                return
            first_line = lines[0].decode("utf-8", errors="ignore").strip()
            if not first_line or " " not in first_line:
                return
            parts = first_line.split(" ")
            if len(parts) < 2:
                return

            if first_line.startswith("HTTP/"):
                # Status line of a response
                log_response(
                    parts[1],
                    f"https://{host}",
                    ProxyHandler._mitm_parse_headers(lines[1:]),
                    None,  # Don't try to parse body
                    True,  # This is HTTPS
                )
                return

            method, path = parts[0], parts[1]
            if method in ("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"):
                # Construct full URL
                url = f"https://{host}{path}" if path.startswith("/") else f"https://{host}/{path}"
                log_request(
                    None,
                    method,
                    url,
                    ProxyHandler._mitm_parse_headers(lines[1:]),
                    None,  # Don't try to parse body
                    True,  # This is HTTPS
                )
        except Exception as e:
            logger.debug(f"[MITM] Error parsing HTTP data ({direction}): {e}")

    @staticmethod
    def _mitm_relay(host, ssl_client_socket, ssl_upstream_socket):
        """Forward decrypted data in both directions until either side closes, then close both sockets.

        Meant to run in its own thread. Data read from one socket is buffered and written to the other one
        as soon as that socket is writable; every chunk read is also passed to :meth:`_mitm_log_http`.

        :param host: Upstream host name, used for logging only
        :type host: `str`
        :param ssl_client_socket: TLS socket connected to the client
        :param ssl_upstream_socket: TLS socket connected to the upstream server
        :return: None
        :rtype: None
        """
        # Set sockets to non-blocking mode
        for sock in (ssl_client_socket, ssl_upstream_socket):
            try:
                sock.setblocking(False)
            except Exception as e:
                logger.debug("[MITM] Could not set socket non-blocking: %s", e)

        client_closed = False
        upstream_closed = False
        client_buffer = b""  # Bytes waiting to be written to the client
        upstream_buffer = b""  # Bytes waiting to be written to the upstream server

        while not client_closed and not upstream_closed:
            try:
                # Only wait for writability of a socket that has pending data. Registering both whenever any
                # buffer is non-empty makes select() return immediately for the idle one.
                pending_writes = []
                if upstream_buffer:
                    pending_writes.append(ssl_upstream_socket)
                if client_buffer:
                    pending_writes.append(ssl_client_socket)

                readable, writable, _ = select.select(
                    [ssl_client_socket, ssl_upstream_socket], pending_writes, [], 0.1
                )

                # Handle readable sockets
                for sock in readable:
                    if sock is ssl_client_socket and not client_closed:
                        try:
                            data = sock.recv(4096)
                        except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                            continue  # TLS layer needs more I/O first; retry on the next round
                        except Exception as e:
                            logger.error("[MITM] Client read error: %s", e)
                            client_closed = True
                            continue
                        if data:
                            logger.debug("[MITM] client->upstream received %d bytes", len(data))
                            upstream_buffer += data
                            # Parse HTTP requests from client data
                            ProxyHandler._mitm_log_http(host, data, "client")
                        else:
                            logger.debug("[MITM] client connection closed gracefully")
                            client_closed = True
                    elif sock is ssl_upstream_socket and not upstream_closed:
                        try:
                            data = sock.recv(4096)
                        except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                            continue
                        except Exception as e:
                            logger.error("[MITM] Upstream read error: %s", e)
                            upstream_closed = True
                            continue
                        if data:
                            logger.debug("[MITM] upstream->client received %d bytes", len(data))
                            client_buffer += data
                            # Parse HTTP responses from upstream data
                            ProxyHandler._mitm_log_http(host, data, "upstream")
                        else:
                            logger.debug("[MITM] upstream connection closed gracefully")
                            upstream_closed = True

                # Handle writable sockets
                for sock in writable:
                    if sock is ssl_upstream_socket and upstream_buffer:
                        try:
                            sent = sock.send(upstream_buffer)
                        except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                            continue
                        except Exception as e:
                            logger.error("[MITM] Upstream write error: %s", e)
                            upstream_closed = True
                            continue
                        if sent > 0:
                            logger.debug("[MITM] client->upstream forwarded %d bytes", sent)
                            upstream_buffer = upstream_buffer[sent:]
                    elif sock is ssl_client_socket and client_buffer:
                        try:
                            sent = sock.send(client_buffer)
                        except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                            continue
                        except Exception as e:
                            logger.error("[MITM] Client write error: %s", e)
                            # It is the client side that failed, so that is the side marked as closed
                            client_closed = True
                            continue
                        if sent > 0:
                            logger.debug("[MITM] upstream->client forwarded %d bytes", sent)
                            client_buffer = client_buffer[sent:]
            except Exception as e:
                logger.error("[MITM] General forwarding exception: %s", e)
                break

        # Clean up both sockets. close() always runs, even when shutdown() fails on a broken connection.
        for sock, already_closed in (
            (ssl_client_socket, client_closed),
            (ssl_upstream_socket, upstream_closed),
        ):
            if not already_closed:
                try:
                    sock.shutdown(socket.SHUT_RDWR)
                except Exception as e:
                    logger.debug("[MITM] Socket shutdown failed: %s", e)
            try:
                sock.close()
            except Exception as e:
                logger.debug("[MITM] Socket close failed: %s", e)

    @tornado.gen.coroutine
    def connect(self):
        """Gets called when a connect request is received.

        * The host and port are obtained from the request uri
        * SSL interception is performed by terminating client SSL and establishing upstream SSL
        * An OK response is written back to client
        * Decrypted data is forwarded bidirectionally between client and server

        :return: None
        :rtype: None
        :raises ValueError: If the request uri does not end in ``:<port>``
        """
        # Split on the LAST colon so the colons of a bracketed IPv6 literal ("[::1]:443") stay with the host
        host, _, port = self.request.uri.rpartition(":")
        host = host.strip("[]")
        port = int(port)

        # Log the CONNECT request (HTTPS interception)
        log_request(self.request, "CONNECT", f"{host}:{port}", self.request.headers, None, True)  # This is HTTPS

        try:
            # Get the client stream
            client_stream = self.request.connection.stream
            logger.info("[MITM] Received CONNECT for %s:%d", host, port)

            # Send success response to establish the tunnel
            client_stream.write(b"HTTP/1.1 200 Connection established\r\n\r\n")
            # The raw stream is taken over from here on, so Tornado must not write a response of its own
            self._finished = True
            logger.info("[MITM] Sent 200 Connection established to client for %s:%d", host, port)

            # Set up SSL termination for client connection
            def ssl_client_success(ssl_client_socket):
                """Callback when SSL handshake with the client is successful"""
                try:
                    logger.info("[MITM] SSL handshake with client successful for %s:%d", host, port)

                    def ssl_upstream_success(ssl_upstream_socket):
                        """Callback when SSL handshake with upstream is successful"""
                        logger.info("[MITM] SSL handshake with upstream %s:%d successful", host, port)
                        # Start bidirectional forwarding in a single thread
                        logger.info("[MITM] Starting bidirectional forwarding for %s:%d", host, port)
                        forwarding_thread = threading.Thread(
                            target=ProxyHandler._mitm_relay,
                            args=(host, ssl_client_socket, ssl_upstream_socket),
                            daemon=True,
                        )
                        forwarding_thread.start()

                    def ssl_upstream_failure(ssl_upstream_socket):
                        """Callback when SSL handshake with upstream fails"""
                        logger.error("[MITM] SSL handshake with upstream %s:%d failed", host, port)
                        try:
                            ssl_client_socket.close()
                        except Exception as e:
                            logger.error("[MITM] Exception closing client socket after upstream failure: %s", e)

                    def connect_upstream():
                        """Connect to upstream server and establish SSL"""
                        upstream_socket = None
                        try:
                            logger.info("[MITM] Connecting to upstream %s:%d", host, port)
                            # Use blocking connection for simplicity; create_connection resolves the host
                            # and works for both IPv4 and IPv6 targets
                            upstream_socket = socket.create_connection((host, port))
                            logger.info("[MITM] Connected to upstream %s:%d, starting SSL handshake", host, port)

                            # Set up SSL connection to upstream server
                            starttls(
                                upstream_socket,
                                host,
                                self.application.ca_cert,
                                self.application.ca_key,
                                self.application.ca_key_pass,
                                self.application.certs_folder,
                                success=ssl_upstream_success,
                                failure=ssl_upstream_failure,
                                io_loop=tornado.ioloop.IOLoop.current(),
                                do_handshake_on_connect=False,
                                ssl_version=ssl.PROTOCOL_TLS,
                                server_side=False,
                                validate_cert=False,
                            )
                        except Exception as e:
                            logger.error("[MITM] Error creating upstream connection: %s", e)
                            # Do not leak the half-open upstream socket
                            if upstream_socket is not None:
                                try:
                                    upstream_socket.close()
                                except Exception as e2:
                                    logger.error("[MITM] Exception closing upstream socket: %s", e2)
                            try:
                                ssl_client_socket.close()
                            except Exception as e2:
                                logger.error(
                                    "[MITM] Exception closing client socket after upstream connect error: %s", e2
                                )

                    # Start the upstream connection process
                    connect_upstream()
                except Exception as e:
                    logger.error("[MITM] Exception in ssl_client_success callback: %s", e)

            def ssl_client_failure(ssl_client_socket):
                """Callback when SSL handshake with the client fails"""
                logger.error("[MITM] SSL handshake with client failed for %s:%d", host, port)
                try:
                    client_stream.close()
                except Exception as e:
                    logger.error("[MITM] Exception closing client stream after handshake failure: %s", e)

            logger.info("[MITM] Starting SSL handshake with client for %s:%d", host, port)
            starttls(
                client_stream.socket,
                host,
                self.application.ca_cert,
                self.application.ca_key,
                self.application.ca_key_pass,
                self.application.certs_folder,
                success=ssl_client_success,
                failure=ssl_client_failure,
                io_loop=tornado.ioloop.IOLoop.current(),
                do_handshake_on_connect=False,
                ssl_version=ssl.PROTOCOL_TLS,
                server_side=True,
            )
        except Exception as e:
            logger.error("[MITM] Error in connect method: %s", e)
            self._finished = True
            try:
                self.request.connection.stream.write(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
            except Exception as send_error:
                logger.error("[MITM] Error sending 502 response: %s", send_error)
            try:
                self.request.connection.stream.close()
            except Exception as close_error:
                logger.error("[MITM] Error closing connection after error: %s", close_error)
class CustomWebSocketHandler(tornado.websocket.WebSocketHandler):
    """Class is used for handling websocket traffic.

    * Object of this class replaces the main request handler for a request with header => "Upgrade: websocket"
    * wss:// - CONNECT request is handled by main handler
    """

    def upstream_connect(self, io_loop=None, callback=None):
        """Custom alternative to tornado.websocket.websocket_connect.

        Builds the ``ws``/``wss`` URL of the upstream server, logs the connection, copies the non-restricted
        request headers and starts a ``CustomWebSocketClientConnection`` through the outbound proxy settings
        of the application.

        .. note::

            Returns a future instance.

        :param io_loop: IOLoop to use; defaults to the current one
        :param callback: Optional callable invoked with the connect future once it is done
        :return: The connect future of the upstream connection
        """
        # io_loop is needed or it won't work with Tornado.
        if io_loop is None:
            io_loop = tornado.ioloop.IOLoop.current()

        # During secure communication, we get relative URI, so make them absolute
        if self.request.uri.startswith(self.request.protocol, 0):
            # Normal Proxy Request.
            self.request.url = self.request.uri
        else:
            # Transparent Proxy Request
            self.request.url = "{!s}://{!s}{!s}".format(self.request.protocol, self.request.host, self.request.uri)
        # http -> ws and https -> wss
        self.request.url = self.request.url.replace("http", "ws", 1)

        # Log WebSocket connection
        is_https = self.request.protocol == "https"
        log_request(self.request, "WEBSOCKET", self.request.url, self.request.headers, None, is_https)

        # Have to add cookies and stuff
        request_headers = tornado.httputil.HTTPHeaders()
        for name, value in list(self.request.headers.items()):
            if name not in ProxyHandler.restricted_request_headers:
                request_headers.add(name, value)

        # Build a custom request
        request = tornado.httpclient.HTTPRequest(
            url=self.request.url,
            headers=request_headers,
            proxy_host=self.application.outbound_ip,
            proxy_port=self.application.outbound_port,
            proxy_username=self.application.outbound_username,
            proxy_password=self.application.outbound_password,
        )
        self.upstream_connection = CustomWebSocketClientConnection(io_loop, request)
        if callback is not None:
            io_loop.add_future(self.upstream_connection.connect_future, callback)
        return self.upstream_connection.connect_future

    def _execute(self, transforms, *args, **kwargs):
        """Overriding of a method of WebSocketHandler

        The client-side handshake is postponed until the upstream connection has been established.

        :param transforms: Passed through to ``WebSocketHandler._execute``
        :param args: Passed through to ``WebSocketHandler._execute``
        :param kwargs: Passed through to ``WebSocketHandler._execute``
        :return: None
        :rtype: None
        """

        def start_tunnel(future):
            """A callback which is called when connection to url is successful."""
            # We need upstream to write further messages
            self.upstream = future.result()
            # HTTPRequest needed for caching
            self.handshake_request = self.upstream_connection.request
            # Needed for websocket data & compliance with cache_handler stuff
            self.handshake_request.response_buffer = ""
            # Tiny hack to protect caching (according to websocket standards)
            self.handshake_request.version = "HTTP/1.1"
            # XXX: I dont know why a None is coming
            self.handshake_request.body = self.handshake_request.body or ""
            # The regular procedures are to be done
            tornado.websocket.WebSocketHandler._execute(self, transforms, *args, **kwargs)

        # We try to connect to provided URL & then we proceed with connection on client side.
        self.upstream = self.upstream_connect(callback=start_tunnel)

    def store_upstream_data(self, message):
        """Save websocket data sent from client to server. i.e add it to HTTPRequest.response_buffer with
        direction (>>)

        :param message: Message to be stored
        :type message: `str`
        :return: None
        :rtype: None
        """
        try:
            # Cannot write binary content as a string, so catch it
            self.handshake_request.response_buffer += ">>> {}\r\n".format(message)
        except TypeError:
            self.handshake_request.response_buffer += ">>> May be binary\r\n"

    def store_downstream_data(self, message):
        """Save websocket data sent from server to client. i.e add it to HTTPRequest.response_buffer with
        direction (<<)

        :param message: Downstream data
        :type message: `str`
        :return: None
        :rtype: None
        """
        try:
            # Cannot write binary content as a string, so catch it.
            self.handshake_request.response_buffer += "<<< {}\r\n".format(message)
        except TypeError:
            self.handshake_request.response_buffer += "<<< May be binary\r\n"

    def on_message(self, message):
        """Everytime a message is received from client side, this instance method is called.

        :param message: Message to write or store
        :type message: `str`
        :return: None
        :rtype: None
        """
        self.upstream.write_message(message)  # The obtained message is written to upstream.
        self.store_upstream_data(message)

        # The following check ensures that if a callback is added for reading message from upstream, another one
        # is not added.
        if not self.upstream.read_future:
            # A callback is added to read the data when upstream responds.
            self.upstream.read_message(callback=self.on_response)

    def on_response(self, message):
        """A callback when a message is received from upstream.

        :param message: Future whose result is the upstream message (a falsy result closes the connection)
        :return: None
        :rtype: None
        """
        # The following check ensures that if a callback is added for reading message from upstream, another one
        # is not added
        if not self.upstream.read_future:
            self.upstream.read_message(callback=self.on_response)
        if self.ws_connection:  # Check if connection still exists.
            if message.result():  # Check if it is not NULL (indirect checking of upstream connection).
                self.write_message(message.result())  # Write obtained message to client.
                self.store_downstream_data(message.result())
            else:
                self.close()

    def on_close(self):
        """Called when websocket is closed. So handshake request-response pair along with websocket data as
        response body is saved

        Nothing is saved when the socket is closed before the upstream handshake has completed.

        :return: None
        :rtype: None
        """
        # handshake_request is only assigned once the upstream connect callback has run, and the status code
        # only once the upstream handshake response has arrived. Reading them unconditionally would raise
        # AttributeError when the socket is closed earlier.
        handshake_request = getattr(self, "handshake_request", None)
        status_code = getattr(getattr(self, "upstream_connection", None), "code", None)
        if handshake_request is None or status_code is None:
            logger.debug("WebSocket closed before the upstream handshake completed, nothing to cache")
            return

        # Required for cache_handler
        self.handshake_response = tornado.httpclient.HTTPResponse(
            handshake_request,
            status_code,
            headers=self.upstream_connection.headers,
            request_time=0,
        )
        # Procedure for dumping a tornado request-response
        self.cache_handler = CacheHandler(
            self.application.cache_dir,
            handshake_request,
            self.application.cookie_regex,
            self.application.cookie_blacklist,
        )
        self.cached_response = self.cache_handler.load()
        self.cache_handler.dump(self.handshake_response)
class CustomWebSocketClientConnection(tornado.websocket.WebSocketClientConnection):
    """Websocket client connection that remembers the status code of the upstream handshake response."""

    def _handle_1xx(self, code):
        """Had to extract response code, so it is necessary to override.

        The code is stored as ``self.code`` before the parent implementation runs.

        :param code: status code
        :type code: `int`
        :return: None
        :rtype: None
        """
        self.code = code
        super()._handle_1xx(code)