#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
#
# map-color-index - Map from one palette index to another
#
# Often there are PNG files that use an index that should not be used. For
# example, in the Doom palette index 255 is used for transparency in some
# legacy engines, so it can be helpful to avoid it. Color 255 can be mapped to
# 133 since they are visually similar.
#
# This program strives to be the following:
#   1) Be fast in the common case of the PNG not needing to be changed.
#   2) Be as safe as possible by making as few extraneous changes as possible.
#      All chunks other than the pixels (IDAT) are preserved.
#   3) Be as safe as possible by only modifying relevant PNGs. For
#      additional safety a palette can be passed (-p option) which must be
#      matched exactly for any change to be made.
#
# Normally this script is invoked by "make":
#   make fix-legacy-transparency-pngs

# Imports

import argparse
import os
import png
import sys

# Globals. Alphabetical.

# Command line arguments. Starts out as an empty placeholder and is replaced
# by the argparse.Namespace built in parse_args().
args = {}

# The palette in byte form (binary read of args.palette). Stays None when no
# palette (-p) was given.
palette_bytes = None

# The number of RGB entries in the args.palette (three bytes per entry). Stays
# 0 when no palette was given.
palette_size = 0

# The number of PNGs that were actually changed.
pngs_changed_count = 0

# The total number of PNGs that were examined.
pngs_examined_count = 0

# Program entry point: run every phase of the script, in order.
def main():
    # Argument parsing must come first since the later phases read 'args'.
    for phase in (parse_args, parse_palette, process_dir, summarize):
        phase()

# Parse the command line arguments and store the result in 'args'.
#
# Returns the parsed argparse.Namespace, which is the same object that is
# stored in the global 'args'. If FROM-INDEX or TO-INDEX is outside 0-255 the
# script exits through parser.error() (usage message, exit status 2).
def parse_args():
    global args

    parser = argparse.ArgumentParser(
        description="Map from one palette index to another.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # The following is sorted by long argument.

    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        help="Debug (lots of) output. Implies verbose.")
    parser.add_argument(
        "-k",
        "--keep",
        action="store_true",
        help="Keep temporary files.")
    parser.add_argument(
        "-p",
        "--palette",
        help="If specified then only PNGs with precisely this palette will be " +
             "altered. This is tightly packed list of RGB three byte pairs, " +
             "just like the PLTE chunk. \"lumps/playpal/playpal-base.lmp\" " +
             "is a common value.")
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Verbose output. Typically one line per PNG.")
    parser.add_argument(
        "dir",
        metavar="DIR",
        help="Directory to process recursively.")
    parser.add_argument(
        "from_index",
        metavar="FROM-INDEX",
        type=int,
        help="Change from this index.")
    parser.add_argument(
        "to_index",
        metavar="TO-INDEX",
        type=int,
        help="Change to this index.")

    args = parser.parse_args()

    # Additional custom argument parsing and checking.

    # A palette holds at most 256 entries (768 bytes of RGB), so an index
    # outside 0-255 can never be valid. Reject it up front rather than
    # silently matching nothing (from-index) or failing part way through
    # rewriting a PNG (to-index).
    for label, index in (("FROM-INDEX", args.from_index),
                         ("TO-INDEX", args.to_index)):
        if not 0 <= index <= 255:
            parser.error(
                "{0} must be in the range 0-255, got {1}.".format(label, index))

    # Writing to "args" is fine; argparse.Namespace is a plain mutable object.
    # --debug implies --verbose.
    if args.debug:
        args.verbose = True

    return args

# Load the reference palette named by args.palette, a file containing
# RGBRGBRGB ..., into 'palette_bytes' and record its entry count in
# 'palette_size'. Does nothing when no palette was given. Prints a message to
# stderr and exits with status 1 if the file length is not a multiple of
# three or is larger than 768 bytes.
def parse_palette():
    global palette_bytes
    global palette_size

    palette_path = args.palette
    if not palette_path:
        return

    with open(palette_path, "rb") as palette_file:
        palette_bytes = palette_file.read()

    # Work out which (if any) sanity check the file fails; the multiple of
    # three check takes precedence.
    byte_count = len(palette_bytes)
    complaint = None
    if byte_count % 3 != 0:
        complaint = "is not a multiple of three."
    elif byte_count > 768:
        complaint = "exceeds 768."
    if complaint is not None:
        print("The length of", palette_path, complaint, file=sys.stderr)
        sys.exit(1)

    # Each palette entry occupies three bytes (R, G and B).
    palette_size = byte_count // 3
    if args.debug:
        print("palette_bytes", palette_bytes)
        print("palette_size", palette_size)

# Process a directory recursively for PNG files.
#
# Walks args.dir and passes every file whose name ends in ".png" (compared
# case-insensitively) to process_png(). 'pngs_examined_count' is incremented
# for each such file and 'pngs_changed_count' for each one that process_png()
# reports as changed.
def process_dir():
    global pngs_changed_count
    global pngs_examined_count

    # os.walk() silently skips directories it cannot list unless it is given
    # an error callback. Report them so that a mistyped or unreadable DIR is
    # not mistaken for a directory that simply contains no PNGs.
    def report_walk_error(error):
        print("Warning:", error, file=sys.stderr)

    for dirpath, _dirnames, filenames in os.walk(
            args.dir, onerror=report_walk_error):
        for png_base in filenames:
            if not png_base.lower().endswith(".png"):
                continue
            pngs_examined_count += 1
            if process_png(os.path.join(dirpath, png_base)):
                pngs_changed_count += 1

# Process a PNG file in place.
#
# Returns True if png_path was rewritten with from-index mapped to to-index,
# or False if the file was left untouched (temporary file, not indexed,
# from-index absent, palette mismatch or to-index outside the palette).
def process_png(png_path):
    # This method works by first seeing if the PNG needs to be modified, and if
    # so a temporary PNG is built with from-index replaced with to-index. The
    # IDAT (pixels) are then taken from that temporary PNG and combined with
    # the chunks from the original PNG in order to produce a PNG where only the
    # IDAT is modified.
    #
    # Prefix naming conventions used in this method:
    #   old_ - Taken from png_path in its original form.
    #   tmp_ - A temporary file that has the correct new IDAT, but otherwise
    #          unwanted chunks.
    #   new_ - The new png_path.
    # Some variables that lack a prefix apply to more than one of the above.

    # If verbose print the last 22 characters of the relative png_path followed
    # by two spaces for a total width of 24. If not debug then end=""
    # so that the resolution is on the same line with one line per PNG.
    if args.verbose:
        png_relpath = os.path.relpath(png_path, args.dir)
        print("{0:22s}  ".format(png_relpath[-22:]), end=("\n" if args.debug else ""))

    # Ignore temporary files.
    if png_path.endswith("-tmp.png"):
        if args.verbose:
            print("temporary file")
        return False

    # Read the existing old png_path as a flat array of pixels as well as some
    # other information (kwargs). Note that png.Reader only allows images to be
    # read once in a single pass, so it's sometimes necessary to open the same
    # PNG multiple times. The file is opened here, rather than by png.Reader,
    # so that the "with" statement closes the handle as soon as it is done.
    with open(png_path, "rb") as old_hand:
        old_flat = png.Reader(file=old_hand).read_flat()

    # Get the existing pixels. The kwargs is a map that describes the mode and
    # style of the PNG, so it can be replicated.
    old_pixels = old_flat[2]
    kwargs = old_flat[3]

    if args.debug:
        print("old_flat", old_flat)

    # Before doing any other processing return quickly if the image is not
    # reasonable.

    # If it does not have a palette then it is not even indexed.
    if "palette" not in kwargs:
        if args.verbose:
            print("no palette")
        return False

    # Is there a from-index to change?
    if args.from_index not in old_pixels:
        if args.verbose:
            print("from-index not found")
        return False

    if args.palette:
        # Check that the palette has the correct size first since it's fast.
        if len(kwargs["palette"]) != palette_size:
            if args.verbose:
                print("palette incorrect size")
            return False

        # Check that PLTE exactly matches the palette specified.
        with open(png_path, "rb") as old_hand:
            old_plte = png.Reader(file=old_hand).chunk("PLTE")[1]
        if old_plte != palette_bytes:
            if args.verbose:
                print("palette does not match")
            return False

    # Pixels must only refer to entries that exist in the PNG's own palette,
    # so refuse to map to an index that this palette does not contain.
    if not 0 <= args.to_index < len(kwargs["palette"]):
        if args.verbose:
            print("to-index not in palette")
        return False

    # Open png_path again to get the chunks. They are read into a list so that
    # the file can be closed before it is overwritten below.
    with open(png_path, "rb") as old_hand:
        old_chunks = list(png.Reader(file=old_hand).chunks())

    # Build a new list of pixels with the index replaced.
    if args.debug:
        print("old_pixels", old_pixels)
    new_pixels = [args.to_index if p == args.from_index else p for p in old_pixels]
    if args.debug:
        print("new_pixels", new_pixels)

    # Remove keys from kwargs that are known to cause trouble with writing.
    kwargs.pop("background", None)

    # Create a temporary PNG with the desired pixels, but not the desired
    # chunks. Passing kwargs assures that the temporary PNG has the same mode
    # as the original one.
    tmp_writer = png.Writer(**kwargs)
    tmp_path = png_path[:-4] + "-tmp.png"
    try:
        with open(tmp_path, "wb") as tmp_hand:
            tmp_writer.write_array(tmp_hand, new_pixels)

        # TODO: It would be nice if there was access to the internal API that
        # generates the IDAT without creating temporary files.
        #
        # Collect every IDAT chunk of the temporary PNG, not just the first
        # one, since the image data may be split across several chunks.
        with open(tmp_path, "rb") as tmp_hand:
            tmp_idat = [chunk for chunk in png.Reader(file=tmp_hand).chunks()
                        if chunk[0] == "IDAT"]
    finally:
        # Clean up even if writing or reading the temporary PNG failed.
        if not args.keep and os.path.exists(tmp_path):
            os.remove(tmp_path)
    if args.debug:
        print("tmp_idat", tmp_idat)

    # Create the new image with the new pixels, but the old chunks. A PNG may
    # store its image data in several consecutive IDAT chunks, so the complete
    # new IDAT sequence is inserted where the first old IDAT was and every
    # further old IDAT is dropped. Replacing each old IDAT individually would
    # duplicate the image data.
    new_chunks_array = []
    idat_replaced = False
    for chunk in old_chunks:
        if chunk[0] != "IDAT":
            new_chunks_array.append(chunk)
        elif not idat_replaced:
            new_chunks_array.extend(tmp_idat)
            idat_replaced = True
    if args.debug:
        print("new_chunks_array", new_chunks_array)

    # Overwrite the existing png_path with the new chunks.
    with open(png_path, "wb") as new_hand:
        png.write_chunks(new_hand, new_chunks_array)

    if args.verbose:
        print("changed")
    if args.debug:
        print()

    return True

# Print a short report: a blank separator line, the directory that was
# processed and how many PNGs were examined and changed. This only prints; it
# does not exit.
def summarize():
    report_lines = (
        (),
        ("Processed directory", args.dir),
        ("Examined", pngs_examined_count, "PNGs."),
        ("Changed ", pngs_changed_count, "PNGs."),
    )
    # Each tuple holds the fields of one output line; the empty tuple yields
    # the blank line.
    for fields in report_lines:
        print(*fields)

# Only run main() when this file is executed as a script, so that it can also
# be imported as a module without starting any processing.
if __name__ == "__main__":
    main()
