Application-centric CloudWatch logging in AWS Lambda functions with Python 3 runtimes

Abstract

Application-centric logging is a scheme in which one or more components all direct their log entries to a single logger. In the AWS context, this could mean an application composed of one or more AWS Lambda functions, each logging to a single application-wide AWS CloudWatch log stream. By “single”, I mean single to the application, not to each function.

In Lambda functions with Python runtimes, the default mode of logging is one log stream per Lambda function. We can log this way via the print() function or the logging module. But there are situations where multiple Lambdas co-operate to solve a larger problem, where the co-operation is synchronous, and where it would be valuable to view a unified log stream of events across all of them.
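
For reference, the default per-function mode looks like the following minimal sketch (the handler and message are illustrative; each function's entries land in its own log group, typically /aws/lambda/&lt;function-name&gt;):

import logging

logger = logging.getLogger()
logger.setLevel( logging.INFO)

def lambda_handler( event, context):
  # Each entry goes to this function's own CloudWatch log stream.
  logger.info( 'Handled one event')
  return {'statusCode': 200}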

How do we achieve this, with a simple client interface? In this post, I present a solution.

A Solution

Include the following Python 3 script, with the filename “custom_logging.py”, in the packaging of your AWS Lambda function.

################################################################################
##  CustomLogging class
################################################################################
import boto3
import json
import time, sys
import logging

def coerceLoggingType( logType):
  if (logType is None) or (logType == ''):
    logType = logging.INFO
  elif isinstance( logType, str):
    logType = getattr( logging, logType.upper(), logging.INFO)
  return logType

defaultFormat = '#standard'

stockFormats = {
  '#standard': '{level}: {func}: {caller}: ',
  '#short'   : '{level}: {func}: ',
  '#simple'  : '{func}: '} 

levelNames = {
  logging.DEBUG   : 'DEBUG',
  logging.INFO    : 'INFO',
  logging.WARNING : 'WARNING',
  logging.ERROR   : 'ERROR',
  logging.CRITICAL: 'CRITICAL'}

botoLoggers = ['boto', 'boto3', 'botocore', 'urllib3']

def _json_formatter( obj):
  """Formatter for unserialisable values."""
  return str( obj)

class JsonFormatter( logging.Formatter):
  """AWS Lambda Logging formatter.
  Formats the log message as a JSON encoded string.  If the message is a
  dict it will be used directly.  If the message can be parsed as JSON, then
  the parsed value is used in the output record.
  """
  def __init__( self, **kwargs):
    super( JsonFormatter, self).__init__()
    # Pop json_default before it can leak into the format dictionary below.
    self.default_json_formatter = kwargs.pop( 'json_default', _json_formatter)
    self.format_dict = {
      'timestamp': '%(asctime)s',
      'level': '%(levelname)s',
      'location': '%(name)s.%(funcName)s:%(lineno)d'}
    self.format_dict.update( kwargs)

  def format( self, record):
    record_dict = record.__dict__.copy()
    record_dict['asctime'] = self.formatTime( record)
    log_dict = {
      k: v % record_dict
      for k, v in self.format_dict.items() if v}
    if isinstance( record_dict['msg'], dict):
      log_dict['message'] = record_dict['msg']
    else:
      log_dict['message'] = record.getMessage()
    # Attempt to decode the message as JSON, if so, merge it with the
    # overall message for clarity.
    try:
      log_dict['message'] = json.loads( log_dict['message'])
    except ( TypeError, ValueError):
      pass
    if record.exc_info:
      # Cache the traceback text to avoid converting it multiple times
      # (it's constant anyway)
      # from logging.Formatter:format
      if not record.exc_text:
        record.exc_text = self.formatException( record.exc_info)
    if record.exc_text:
      log_dict['exception'] = record.exc_text
    json_record = json.dumps( log_dict, default=self.default_json_formatter)
    if hasattr( json_record, 'decode'):  # pragma: no cover
      json_record = json_record.decode( 'utf-8')
    return json_record

def setupCanonicalLogLevels( logger, level, fmt, formatter_cls=JsonFormatter, boto_level=None, **kwargs):
  if not isinstance( logger, logging.Logger):
    raise Exception( 'Wrong class of logger passed to setupCanonicalLogLevels().')
  logger.setLevel( level)
  logging.root.setLevel( level)
  if fmt is not None:
    logging.basicConfig( format=fmt)
    fmtObj = logging.Formatter( fmt)
  else:
    fmtObj = None
  for handler in logging.root.handlers:
    try:
      if fmtObj is not None:
        handler.setFormatter( fmtObj)
      elif formatter_cls is not None:
        handler.setFormatter( formatter_cls( **kwargs))
    except:
      pass
  if boto_level is None:
    boto_level = level
  for loggerId in botoLoggers:
    try:
      logging.getLogger( loggerId).setLevel( boto_level)
    except:
      pass
 
 
class NullLogger():
  def __init__( self):
    pass
 
  def purge( self):
    pass
 
  def log( self, level, msg, withPurge=False):
    pass
 
  def debug( self, msg, withPurge=False):
    pass
 
  def info( self, msg, withPurge=False):
    pass
 
  def warning( self, msg, withPurge=False):
    pass
 
  def critical( self, msg, withPurge=False):
    pass
 
  def error( self, msg, withPurge=False):
    pass
 
  def exception( self, msg, withPurge=False):
    pass
 
  def classCode( self):
    return '#null'
 
  def isPurgeable( self):
    return False
 
class PrintLogger():
  def __init__( self, threshold):
    self.threshold = threshold
 
  def purge( self):
    pass
 
  def log( self, level, msg, withPurge=False):
    if level >= self.threshold:
      print( msg)
 
  def debug( self, msg, withPurge=False):
    self.log( logging.DEBUG, msg, False)
 
  def info( self, msg, withPurge=False):
    self.log( logging.INFO, msg, False)
 
  def warning( self, msg, withPurge=False):
    self.log( logging.WARNING, msg, False)
 
  def critical( self, msg, withPurge=False):
    self.log( logging.CRITICAL, msg, False)
 
  def error( self, msg, withPurge=False):
    self.log( logging.ERROR, msg, False)
 
  def exception( self, msg, withPurge=False):
    self.log( logging.ERROR, msg, False)
 
  def classCode( self):
    return '#print'
 
  def isPurgeable( self):
    return False
 
def createPolymorphicLogger( logClass, logGroup, logStream, logLevel = logging.INFO, functionName = None, msgFormat = None):
  if logClass == 'cloud-watch':
    return CustomLogging( logGroup, logStream, logLevel, functionName, msgFormat)
  elif logClass == '#print':
    return PrintLogger( logLevel)
  elif (logClass == '#null') or (logClass is None):
    return NullLogger()
  elif isinstance( logClass, dict) and ('logging' in logClass):
    loggingParams    = logClass.get( 'logging', {})
    cloudWatchParams = loggingParams.get( 'cloud-watch', {})
    if msgFormat is None:
      msgFormat = '#mini'
    actualLogClass  = loggingParams.get( 'class')
    logGroup     = cloudWatchParams.get( 'group'   , logGroup)
    logStream    = cloudWatchParams.get( 'stream'  , logStream)
    logLevel     =    loggingParams.get( 'level'   , logLevel)
    functionName = cloudWatchParams.get( 'function', functionName)
    msgFormat    = cloudWatchParams.get( 'format'  , msgFormat)
    return createPolymorphicLogger( actualLogClass, logGroup, logStream, logLevel, functionName, msgFormat)
  elif isinstance( logClass, dict) and ('class' in logClass):
    canonicalLogClassRecord = {'logging': logClass}
    return createPolymorphicLogger( canonicalLogClassRecord, logGroup, logStream, logLevel, functionName, msgFormat)
  elif logClass == '#standard-logger':
    logger = logging.getLogger( name=logStream)
    if msgFormat is None:
      msgFormat = '[%(levelname)s] %(message)s'
    setupCanonicalLogLevels( logger, logLevel, msgFormat, JsonFormatter, logging.ERROR)
    return logger
  else:
    raise Exception( f'Unrecognised log class {logClass}')
 
def getClassCode( logger):
  code = '#null'
  if isinstance( logger, logging.Logger):
    code = '#standard-logger'
  elif logger is not None:
    try:
      code = logger.classCode()
    except:
      code = '#unrecognised'
  return code
 
def isLoggerPurgeable( logger):
  result = False
  if (not isinstance( logger, logging.Logger)) and (logger is not None):
    try:
      result = logger.isPurgeable()
    except:
      pass
  return result
 
class CustomLogging:
  def __init__( self, logGroup, logStream, logLevel = logging.INFO, functionName = None, msgFormat = None):
    """ logGroup is the name of the CloudWatch log group. If none, the messages passes to print.
        logStream is the name of the stream. It is required. It is a string. There is no embedded date processing.
        logLevel is one of the logging level constants or its string equivalent. Posts below this level will be swallowed.
        functionName is the name of the lambda.
        msgFormat determines the logged message prefix. It is either a format string, a label or a function.
          If it is a format string, the following substitution identifiers:
            {level}  The message log level.
            {func}   The passed functionName
            {caller} The python caller function name
          If it is a label, is one of:
            #standard   - This is the default.
            #short
            #simple
            #mini
          If it is a function (or callable object), it must be a function that returns a prefix string with
            the following input parameters in order:
              level           - passed message level
              functionName  - constructed function name
              caller          - invoker caller name
              logMsg          - passed message
             
        EXAMPLE USAGE 1:
          import custom_logging, logging
         
          logger = CustomLogging( '/aws/ec2/prod/odin', '2022-06-29-MLC_DAILY-143', logging.INFO, 'CoolLambdaFunc', '#mini')
          logger.info( 'Hello friend! This is an info')
          logger.error( 'I broke it!')
          logger.purge()
       
        
        EXAMPLE USAGE 2:
          import custom_logging, logging
         
          logger = CustomLogging( None, None, logging.DEBUG, 'CoolLambdaFunc', '#mini')
          logger.info( 'This is the same as print')
      
        
        EXAMPLE USAGE 3:
          import custom_logging, logging
         
          logger = CustomLogging( None, None, logging.WARNING, 'CoolLambdaFunc', '{caller} | {level} !! {func}: ')
          
       
        
        EXAMPLE USAGE 4:
          import custom_logging, logging
         
          def colourMePink( level, functionName, caller, logMsg):
            sLevel = levelNames.get( level, str( level))
            if level == logging.DEBUG:
              prefix = '{level}: {func}: {caller}: '.format( level = sLevel, func = functionName, caller = caller)
            elif level == logging.INFO:
              prefix = ''
            else:
              prefix = '{level}: '.format( level = sLevel)
            return prefix
         
          logger = CustomLogging( None, None, logging.INFO, None, colourMePink)
          
    """
    self.logs           = boto3.client( 'logs', region_name='ap-southeast-2')  # NOTE: region is hardcoded here; adjust to suit.
    self.logEvents      = []
    self.functionName = functionName
    if self.functionName is None:
      self.functionName = ''
    self.logGroup       = logGroup
    self.logStream      = logStream
    self.msgFormat = msgFormat
    if self.msgFormat is None:
      self.msgFormat = defaultFormat
    if isinstance( self.msgFormat, str) and (self.msgFormat in stockFormats):
      self.msgFormat = stockFormats[self.msgFormat]
    elif self.msgFormat == '#mini':
      self.msgFormat = self._miniFormat
    self.logLevel       = coerceLoggingType( logLevel)
    self.sequenceToken  = None
    self.sequenceTokenIsValid = False
    self.maxEventsInBuffer = 20
    self.maxBufferAgeMs = 60000 # 1 minute.
 
  def _formatMessage( self, caller, logType, logMsg):
    prefix = ''
    if caller is None:
      try:
        caller = sys._getframe(3).f_code.co_name
      except:
        caller = ''
    sLevel = levelNames.get( logType, str( logType))
    if isinstance( self.msgFormat, str):
      prefix = self.msgFormat.format( level = sLevel, func = self.functionName, caller = caller)
    elif callable( self.msgFormat):
      prefix = self.msgFormat( logType, self.functionName, caller, logMsg)
    return prefix + str( logMsg)
 
  def _miniFormat( self, level, functionName, caller, logMsg):
    prefix = ''
    if level >= logging.WARNING:
      prefix = levelNames[ level] + ': '
    if functionName != '':
      prefix = prefix + functionName + ': '
    return prefix
 
  def _getSequenceToken( self):
    self.sequenceToken = None
    self.sequenceTokenIsValid = True
    try:
      response = self.logs.describe_log_streams( logGroupName=self.logGroup, logStreamNamePrefix=self.logStream)
    except self.logs.exceptions.ResourceNotFoundException:
      return 'group-not-found'
    try:
      if 'uploadSequenceToken' in response['logStreams'][0]:
        self.sequenceToken = response['logStreams'][0]['uploadSequenceToken']
      if self.sequenceToken == '':
        self.sequenceToken = None
    except:
      pass
    if self.sequenceToken is None:
      return 'stream-not-found-or-virgin-stream'
    else:
      return None
 
  def put( self, logMsg, logType = logging.INFO, withPurge=False, callFunc = None):
    logType = coerceLoggingType( logType)
    if self.logLevel <= logType:
      if self.logGroup is not None:
        timestamp = int( round( time.time() * 1000))
        message = self._formatMessage( callFunc, logType, logMsg)
        logEvent = {'timestamp': timestamp, 'message': message}
        if self.logLevel == logging.DEBUG:
          print( message)
        self.logEvents.append( logEvent)
        count = len( self.logEvents)
        if withPurge or \
           (count >= self.maxEventsInBuffer) or \
           ((count >= 1) and ((timestamp - self.logEvents[0]['timestamp']) >= self.maxBufferAgeMs)):
          self.purge()
      else:
        print( logMsg)
 
  def classCode( self):
    return 'cloud-watch'
 
  def _primitive_put_log_events( self):
    event_log = {
      'logGroupName' : self.logGroup,
      'logStreamName': self.logStream,
      'logEvents'    : self.logEvents}
    if self.sequenceToken is not None:
      event_log['sequenceToken'] = self.sequenceToken
    try:
      response = self.logs.put_log_events( **event_log)
      self.sequenceToken = response.get( 'nextSequenceToken')
      self.sequenceTokenIsValid = True
      result = None
    except self.logs.exceptions.ResourceAlreadyExistsException:
      self.sequenceTokenIsValid = False
      result = None
    except self.logs.exceptions.DataAlreadyAcceptedException:
      self.sequenceTokenIsValid = False
      result = None
    except self.logs.exceptions.InvalidSequenceTokenException:
      self.sequenceTokenIsValid = False
      result = 'invalid-sequence-token'
    except self.logs.exceptions.ResourceNotFoundException:
      self.sequenceTokenIsValid = True
      self.sequenceToken = None
      result = 'stream-not-found'
    return result
 
  def _primitive_create_log_stream( self):
    self.sequenceTokenIsValid = True
    self.sequenceToken = None
    try:
      self.logs.create_log_stream( logGroupName=self.logGroup, logStreamName=self.logStream)
      result = None
    except self.logs.exceptions.ResourceAlreadyExistsException:
      self.sequenceTokenIsValid = False
      result = None
    except self.logs.exceptions.ResourceNotFoundException:
      result = 'group-not-found'
    return result
 
  def _primitive_create_log_group( self):
    self.sequenceTokenIsValid = True
    self.sequenceToken = None
    try:
      self.logs.create_log_group( logGroupName=self.logGroup)
    except self.logs.exceptions.ResourceAlreadyExistsException:
      pass
 
  def _robust_put_log_events( self):
    # Drives a small state machine: create the log group and/or stream if they
    # are missing, refresh the sequence token when it is invalid, then post.
    status = 'hungry'
    for tryCount in range( 100):
      if status == 'group-not-found':
        self._primitive_create_log_group()
        status = 'stream-not-found'
      elif status == 'stream-not-found':
        status = self._primitive_create_log_stream()
        if status is None:
          status = 'hungry'
      elif status == 'invalid-sequence-token':
        getSequenceResult = self._getSequenceToken()
        # getSequenceResult == 'group-not-found' | 'stream-not-found-or-virgin-stream' | None
        if getSequenceResult == 'group-not-found':
          status = 'group-not-found'
        elif getSequenceResult == 'stream-not-found-or-virgin-stream':
          status = 'stream-not-found'
        else:
          status = 'ready'
      elif status == 'hungry':
        if not self.sequenceTokenIsValid:
          status = 'invalid-sequence-token'
        else:
          status = 'ready'
      elif status == 'ready':
        status = self._primitive_put_log_events()
        if status is None:
          status = 'done'
      if status == 'done':
        break
    if status != 'done':
      raise Exception( 'Failed to post to CloudWatch Logs.')
 
  def purge( self):
    if len( self.logEvents) > 0:
      try:
        self._robust_put_log_events()
      except Exception as ex:
        print( self.logEvents)
        print( ex)
      self.logEvents = []
 
  def log( self, level, msg, withPurge=False):
    self.put( msg, level, withPurge, None)
 
  def debug( self, msg, withPurge=False):
    self.put( msg, logging.DEBUG, withPurge, None)
 
  def info( self, msg, withPurge=False):
    self.put( msg, logging.INFO, withPurge, None)
 
  def warning( self, msg, withPurge=False):
    self.put( msg, logging.WARNING, withPurge, None)
 
  def error( self, msg, withPurge=False):
    self.put( msg, logging.ERROR, withPurge, None)
 
  def critical( self, msg, withPurge=True):
    self.put( msg, logging.CRITICAL, withPurge, None)
 
  def exception( self, msg, withPurge=True):
    self.log( logging.ERROR, msg, withPurge)
 
  def isPurgeable( self):
    return True
 
  def __del__( self):
    try:
      self.purge()
    except:
      pass

How to use

Import custom_logging. In your lambda code, where you need application-centric logging, invoke the factory function createPolymorphicLogger() to create a logger. Then send all your application-centric log events to this logger instead of print().
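
For example, a minimal sketch of a handler using the CloudWatch-backed logger (the group, stream and function names are taken from the docstring examples and are purely illustrative):

import logging
import custom_logging

def lambda_handler( event, context):
  # One application-wide group/stream shared by all co-operating lambdas.
  logger = custom_logging.createPolymorphicLogger(
    'cloud-watch', '/aws/ec2/prod/odin', '2022-06-29-MLC_DAILY-143',
    logging.INFO, 'CoolLambdaFunc', '#mini')
  logger.info( 'Starting work')
  # ... do the work ...
  logger.purge()  # flush buffered events before the lambda exits
  return {'statusCode': 200}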

The logger is going to have the following public methods.

  • purge()
  • log( level, msg, withPurge=False)
  • debug/info/warning/critical/error/exception( msg, withPurge=False)

Use the log() method to log a string message. ‘level’ is one of the usual logging levels: DEBUG, INFO, etc. For performance reasons, messages are buffered before being sent to CloudWatch. The buffer is purged when: (A) the buffer grows too long; (B) the buffer ages out (1 minute); or (C) the withPurge parameter is explicitly set to True. Invoking the purge() method, or releasing the custom logger instance, will also flush it.
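
To make the purge triggers concrete, a short sketch (continuing with the logger created above):

logger.debug( 'buffered until a purge trigger fires')
logger.info( 'also buffered')
logger.error( 'flushes the whole buffer now', withPurge=True)  # trigger (C)
logger.purge()  # explicit flush; also happens when the instance is released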

The debug() etc. methods are shorthand for the log() method with the level fixed.

How to configure it

Refer to the docstring of CustomLogging.__init__() and the inline comments, particularly the example usages.
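
Note that createPolymorphicLogger() also accepts a configuration dict in place of the class label, which is convenient when logging settings come from a config file. A sketch of the shape it expects, with illustrative values:

import logging
import custom_logging

config = {
  'logging': {
    'class': 'cloud-watch',
    'level': logging.INFO,
    'cloud-watch': {
      'group'   : '/aws/ec2/prod/odin',
      'stream'  : '2022-06-29-MLC_DAILY-143',
      'function': 'CoolLambdaFunc',
      'format'  : '#mini'}}}

logger = custom_logging.createPolymorphicLogger( config, None, None)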
