4@contact: sherlockliao01@gmail.com
7""" AutoAugment, RandAugment, and AugMix for PyTorch
8This code implements the searched ImageNet policies with various tweaks and improvements and
9does not include any of the search code.
10AA and RA Implementation adapted from:
11 https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
13 https://github.com/google-research/augmix
15 AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
16 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
17 RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
18 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
19Hacked together by Ross Wightman
27from PIL
import Image, ImageOps, ImageEnhance
29_PIL_VER = tuple([int(x)
for x
in PIL.__version__.split(
'.')[:2]])
31_FILL = (128, 128, 128)
37_HPARAMS_DEFAULT = dict(
42_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
46 interpolation = kwargs.pop(
'resample', Image.BILINEAR)
47 if isinstance(interpolation, (list, tuple)):
48 return random.choice(interpolation)
54 if 'fillcolor' in kwargs
and _PIL_VER < (5, 0):
55 kwargs.pop(
'fillcolor')
61 return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
66 return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
70 pixels = pct * img.size[0]
72 return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
76 pixels = pct * img.size[1]
78 return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
83 return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
88 return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
93 if _PIL_VER >= (5, 2):
94 return img.rotate(degrees, **kwargs)
95 elif _PIL_VER >= (5, 0):
98 rotn_center = (w / 2.0, h / 2.0)
99 angle = -math.radians(degrees)
101 round(math.cos(angle), 15),
102 round(math.sin(angle), 15),
104 round(-math.sin(angle), 15),
105 round(math.cos(angle), 15),
def transform(x, y, matrix):
    """Apply a 2D affine transform to the point ``(x, y)``.

    :param x: x coordinate of the input point
    :param y: y coordinate of the input point
    :param matrix: sequence of exactly six affine coefficients ``(a, b, c, d, e, f)``
    :return: tuple ``(a * x + b * y + c, d * x + e * y + f)``
    :raises ValueError: if ``matrix`` does not unpack into exactly six values
    """
    a, b, c, d, e, f = matrix
    return a * x + b * y + c, d * x + e * y + f
113 matrix[2], matrix[5] = transform(
114 -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
116 matrix[2] += rotn_center[0]
117 matrix[5] += rotn_center[1]
118 return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
120 return img.rotate(degrees, resample=kwargs[
'resample'])
124 return ImageOps.autocontrast(img)
128 return ImageOps.invert(img)
132 return ImageOps.equalize(img)
136 return ImageOps.solarize(img, thresh)
143 lut.append(min(255, i + add))
146 if img.mode
in (
"L",
"RGB"):
147 if img.mode ==
"RGB" and len(lut) == 256:
148 lut = lut + lut + lut
149 return img.point(lut)
155 if bits_to_keep >= 8:
157 return ImageOps.posterize(img, bits_to_keep)
161 return ImageEnhance.Contrast(img).enhance(factor)
165 return ImageEnhance.Color(img).enhance(factor)
169 return ImageEnhance.Brightness(img).enhance(factor)
173 return ImageEnhance.Sharpness(img).enhance(factor)
177 """With 50% prob, negate the value"""
178 return -v
if random.random() > 0.5
else v
183 level = (level / _MAX_LEVEL) * 30.
190 return (level / _MAX_LEVEL) * 1.8 + 0.1,
196 level = (level / _MAX_LEVEL) * .9
203 level = (level / _MAX_LEVEL) * 0.3
209 translate_const = hparams[
'translate_const']
210 level = (level / _MAX_LEVEL) * float(translate_const)
217 translate_pct = hparams.get(
'translate_pct', 0.45)
218 level = (level / _MAX_LEVEL) * translate_pct
227 return int((level / _MAX_LEVEL) * 4),
241 return int((level / _MAX_LEVEL) * 4) + 4,
247 return int((level / _MAX_LEVEL) * 256),
258 return int((level / _MAX_LEVEL) * 110),
262 'AutoContrast':
None,
265 'Rotate': _rotate_level_to_arg,
267 'Posterize': _posterize_level_to_arg,
268 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
269 'PosterizeOriginal': _posterize_original_level_to_arg,
270 'Solarize': _solarize_level_to_arg,
271 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
272 'SolarizeAdd': _solarize_add_level_to_arg,
273 'Color': _enhance_level_to_arg,
274 'ColorIncreasing': _enhance_increasing_level_to_arg,
275 'Contrast': _enhance_level_to_arg,
276 'ContrastIncreasing': _enhance_increasing_level_to_arg,
277 'Brightness': _enhance_level_to_arg,
278 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
279 'Sharpness': _enhance_level_to_arg,
280 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
281 'ShearX': _shear_level_to_arg,
282 'ShearY': _shear_level_to_arg,
283 'TranslateX': _translate_abs_level_to_arg,
284 'TranslateY': _translate_abs_level_to_arg,
285 'TranslateXRel': _translate_rel_level_to_arg,
286 'TranslateYRel': _translate_rel_level_to_arg,
290 'AutoContrast': auto_contrast,
291 'Equalize': equalize,
294 'Posterize': posterize,
295 'PosterizeIncreasing': posterize,
296 'PosterizeOriginal': posterize,
297 'Solarize': solarize,
298 'SolarizeIncreasing': solarize,
299 'SolarizeAdd': solarize_add,
301 'ColorIncreasing': color,
302 'Contrast': contrast,
303 'ContrastIncreasing': contrast,
304 'Brightness': brightness,
305 'BrightnessIncreasing': brightness,
306 'Sharpness': sharpness,
307 'SharpnessIncreasing': sharpness,
310 'TranslateX': translate_x_abs,
311 'TranslateY': translate_y_abs,
312 'TranslateXRel': translate_x_rel,
313 'TranslateYRel': translate_y_rel,
319 def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
320 hparams = hparams
or _HPARAMS_DEFAULT
327 fillcolor=hparams[
'img_mean']
if 'img_mean' in hparams
else _FILL,
328 resample=hparams[
'interpolation']
if 'interpolation' in hparams
else _RANDOM_INTERPOLATION,
338 if self.
prob < 1.0
and random.random() > self.
prob:
343 magnitude = min(_MAX_LEVEL, max(0, magnitude))
351 [(
'Equalize', 0.8, 1), (
'ShearY', 0.8, 4)],
352 [(
'Color', 0.4, 9), (
'Equalize', 0.6, 3)],
353 [(
'Color', 0.4, 1), (
'Rotate', 0.6, 8)],
354 [(
'Solarize', 0.8, 3), (
'Equalize', 0.4, 7)],
355 [(
'Solarize', 0.4, 2), (
'Solarize', 0.6, 2)],
356 [(
'Color', 0.2, 0), (
'Equalize', 0.8, 8)],
357 [(
'Equalize', 0.4, 8), (
'SolarizeAdd', 0.8, 3)],
358 [(
'ShearX', 0.2, 9), (
'Rotate', 0.6, 8)],
359 [(
'Color', 0.6, 1), (
'Equalize', 1.0, 2)],
360 [(
'Invert', 0.4, 9), (
'Rotate', 0.6, 0)],
361 [(
'Equalize', 1.0, 9), (
'ShearY', 0.6, 3)],
362 [(
'Color', 0.4, 7), (
'Equalize', 0.6, 0)],
363 [(
'Posterize', 0.4, 6), (
'AutoContrast', 0.4, 7)],
364 [(
'Solarize', 0.6, 8), (
'Color', 0.6, 9)],
365 [(
'Solarize', 0.2, 4), (
'Rotate', 0.8, 9)],
366 [(
'Rotate', 1.0, 7), (
'TranslateYRel', 0.8, 9)],
367 [(
'ShearX', 0.0, 0), (
'Solarize', 0.8, 4)],
368 [(
'ShearY', 0.8, 0), (
'Color', 0.6, 4)],
369 [(
'Color', 1.0, 0), (
'Rotate', 0.6, 2)],
370 [(
'Equalize', 0.8, 4), (
'Equalize', 0.0, 8)],
371 [(
'Equalize', 1.0, 4), (
'AutoContrast', 0.6, 2)],
372 [(
'ShearY', 0.4, 7), (
'SolarizeAdd', 0.6, 7)],
373 [(
'Posterize', 0.8, 2), (
'Solarize', 0.6, 10)],
374 [(
'Solarize', 0.6, 8), (
'Equalize', 0.6, 1)],
375 [(
'Color', 0.8, 6), (
'Rotate', 0.4, 5)],
377 pc = [[
AugmentOp(*a, hparams=hparams)
for a
in sp]
for sp
in policy]
385 [(
'Equalize', 0.8, 1), (
'ShearY', 0.8, 4)],
386 [(
'Color', 0.4, 9), (
'Equalize', 0.6, 3)],
387 [(
'Color', 0.4, 1), (
'Rotate', 0.6, 8)],
388 [(
'Solarize', 0.8, 3), (
'Equalize', 0.4, 7)],
389 [(
'Solarize', 0.4, 2), (
'Solarize', 0.6, 2)],
390 [(
'Color', 0.2, 0), (
'Equalize', 0.8, 8)],
391 [(
'Equalize', 0.4, 8), (
'SolarizeAdd', 0.8, 3)],
392 [(
'ShearX', 0.2, 9), (
'Rotate', 0.6, 8)],
393 [(
'Color', 0.6, 1), (
'Equalize', 1.0, 2)],
394 [(
'Invert', 0.4, 9), (
'Rotate', 0.6, 0)],
395 [(
'Equalize', 1.0, 9), (
'ShearY', 0.6, 3)],
396 [(
'Color', 0.4, 7), (
'Equalize', 0.6, 0)],
397 [(
'PosterizeIncreasing', 0.4, 6), (
'AutoContrast', 0.4, 7)],
398 [(
'Solarize', 0.6, 8), (
'Color', 0.6, 9)],
399 [(
'Solarize', 0.2, 4), (
'Rotate', 0.8, 9)],
400 [(
'Rotate', 1.0, 7), (
'TranslateYRel', 0.8, 9)],
401 [(
'ShearX', 0.0, 0), (
'Solarize', 0.8, 4)],
402 [(
'ShearY', 0.8, 0), (
'Color', 0.6, 4)],
403 [(
'Color', 1.0, 0), (
'Rotate', 0.6, 2)],
404 [(
'Equalize', 0.8, 4), (
'Equalize', 0.0, 8)],
405 [(
'Equalize', 1.0, 4), (
'AutoContrast', 0.6, 2)],
406 [(
'ShearY', 0.4, 7), (
'SolarizeAdd', 0.6, 7)],
407 [(
'PosterizeIncreasing', 0.8, 2), (
'Solarize', 0.6, 10)],
408 [(
'Solarize', 0.6, 8), (
'Equalize', 0.6, 1)],
409 [(
'Color', 0.8, 6), (
'Rotate', 0.4, 5)],
411 pc = [[
AugmentOp(*a, hparams=hparams)
for a
in sp]
for sp
in policy]
418 [(
'PosterizeOriginal', 0.4, 8), (
'Rotate', 0.6, 9)],
419 [(
'Solarize', 0.6, 5), (
'AutoContrast', 0.6, 5)],
420 [(
'Equalize', 0.8, 8), (
'Equalize', 0.6, 3)],
421 [(
'PosterizeOriginal', 0.6, 7), (
'PosterizeOriginal', 0.6, 6)],
422 [(
'Equalize', 0.4, 7), (
'Solarize', 0.2, 4)],
423 [(
'Equalize', 0.4, 4), (
'Rotate', 0.8, 8)],
424 [(
'Solarize', 0.6, 3), (
'Equalize', 0.6, 7)],
425 [(
'PosterizeOriginal', 0.8, 5), (
'Equalize', 1.0, 2)],
426 [(
'Rotate', 0.2, 3), (
'Solarize', 0.6, 8)],
427 [(
'Equalize', 0.6, 8), (
'PosterizeOriginal', 0.4, 6)],
428 [(
'Rotate', 0.8, 8), (
'Color', 0.4, 0)],
429 [(
'Rotate', 0.4, 9), (
'Equalize', 0.6, 2)],
430 [(
'Equalize', 0.0, 7), (
'Equalize', 0.8, 8)],
431 [(
'Invert', 0.6, 4), (
'Equalize', 1.0, 8)],
432 [(
'Color', 0.6, 4), (
'Contrast', 1.0, 8)],
433 [(
'Rotate', 0.8, 8), (
'Color', 1.0, 2)],
434 [(
'Color', 0.8, 8), (
'Solarize', 0.8, 7)],
435 [(
'Sharpness', 0.4, 7), (
'Invert', 0.6, 8)],
436 [(
'ShearX', 0.6, 5), (
'Equalize', 1.0, 9)],
437 [(
'Color', 0.4, 0), (
'Equalize', 0.6, 3)],
438 [(
'Equalize', 0.4, 7), (
'Solarize', 0.2, 4)],
439 [(
'Solarize', 0.6, 5), (
'AutoContrast', 0.6, 5)],
440 [(
'Invert', 0.6, 4), (
'Equalize', 1.0, 8)],
441 [(
'Color', 0.6, 4), (
'Contrast', 1.0, 8)],
442 [(
'Equalize', 0.8, 8), (
'Equalize', 0.6, 3)],
444 pc = [[
AugmentOp(*a, hparams=hparams)
for a
in sp]
for sp
in policy]
451 [(
'PosterizeIncreasing', 0.4, 8), (
'Rotate', 0.6, 9)],
452 [(
'Solarize', 0.6, 5), (
'AutoContrast', 0.6, 5)],
453 [(
'Equalize', 0.8, 8), (
'Equalize', 0.6, 3)],
454 [(
'PosterizeIncreasing', 0.6, 7), (
'PosterizeIncreasing', 0.6, 6)],
455 [(
'Equalize', 0.4, 7), (
'Solarize', 0.2, 4)],
456 [(
'Equalize', 0.4, 4), (
'Rotate', 0.8, 8)],
457 [(
'Solarize', 0.6, 3), (
'Equalize', 0.6, 7)],
458 [(
'PosterizeIncreasing', 0.8, 5), (
'Equalize', 1.0, 2)],
459 [(
'Rotate', 0.2, 3), (
'Solarize', 0.6, 8)],
460 [(
'Equalize', 0.6, 8), (
'PosterizeIncreasing', 0.4, 6)],
461 [(
'Rotate', 0.8, 8), (
'Color', 0.4, 0)],
462 [(
'Rotate', 0.4, 9), (
'Equalize', 0.6, 2)],
463 [(
'Equalize', 0.0, 7), (
'Equalize', 0.8, 8)],
464 [(
'Invert', 0.6, 4), (
'Equalize', 1.0, 8)],
465 [(
'Color', 0.6, 4), (
'Contrast', 1.0, 8)],
466 [(
'Rotate', 0.8, 8), (
'Color', 1.0, 2)],
467 [(
'Color', 0.8, 8), (
'Solarize', 0.8, 7)],
468 [(
'Sharpness', 0.4, 7), (
'Invert', 0.6, 8)],
469 [(
'ShearX', 0.6, 5), (
'Equalize', 1.0, 9)],
470 [(
'Color', 0.4, 0), (
'Equalize', 0.6, 3)],
471 [(
'Equalize', 0.4, 7), (
'Solarize', 0.2, 4)],
472 [(
'Solarize', 0.6, 5), (
'AutoContrast', 0.6, 5)],
473 [(
'Invert', 0.6, 4), (
'Equalize', 1.0, 8)],
474 [(
'Color', 0.6, 4), (
'Contrast', 1.0, 8)],
475 [(
'Equalize', 0.8, 8), (
'Equalize', 0.6, 3)],
477 pc = [[
AugmentOp(*a, hparams=hparams)
for a
in sp]
for sp
in policy]
482 hparams = _HPARAMS_DEFAULT
483 if name ==
'original':
485 elif name ==
'originalr':
492 assert False,
'Unknown AA policy (%s)' % name
503 if random.uniform(0, 1) > self.
gamma:
504 sub_policy = random.choice(self.
policy)
506 for op
in sub_policy:
515 Create an AutoAugment transform
516 :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
517 dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
518 The remaining sections, not order specific, determine
519 'mstd' - float std deviation of magnitude noise applied
520 Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
521 :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
522 :return: A PyTorch compatible Transform
524 config = config_str.split(
'-')
525 policy_name = config[0]
528 cs = re.split(
r'(\d.*)', c)
534 hparams.setdefault(
'magnitude_std', float(val))
536 assert False,
'Unknown AutoAugment config section'
560_RAND_INCREASING_TRANSFORMS = [
565 'PosterizeIncreasing',
566 'SolarizeIncreasing',
569 'ContrastIncreasing',
570 'BrightnessIncreasing',
571 'SharpnessIncreasing',
581_RAND_CHOICE_WEIGHTS_0 = {
585 'TranslateXRel': 0.1,
586 'TranslateYRel': 0.1,
589 'AutoContrast': 0.025,
601 transforms = transforms
or _RAND_TRANSFORMS
602 assert weight_idx == 0
603 rand_weights = _RAND_CHOICE_WEIGHTS_0
604 probs = [rand_weights[k]
for k
in transforms]
605 probs /= np.sum(probs)
610 hparams = hparams
or _HPARAMS_DEFAULT
611 transforms = transforms
or _RAND_TRANSFORMS
613 name, prob=0.5, magnitude=magnitude, hparams=hparams)
for name
in transforms]
617 def __init__(self, ops, num_layers=2, choice_weights=None):
624 ops = np.random.choice(
633 Create a RandAugment transform
634 :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
635 dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
636 sections, not order specific, determine
637 'm' - integer magnitude of rand augment
638 'n' - integer num layers (number of transform ops selected per image)
639 'w' - integer probability weight index (index of a set of weights to influence choice of op)
640 'mstd' - float std deviation of magnitude noise applied
641 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
642 Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
643 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
644 :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
645 :return: A PyTorch compatible Transform
647 magnitude = _MAX_LEVEL
650 transforms = _RAND_TRANSFORMS
651 config = config_str.split(
'-')
652 assert config[0] ==
'rand'
655 cs = re.split(
r'(\d.*)', c)
661 hparams.setdefault(
'magnitude_std', float(val))
664 transforms = _RAND_INCREASING_TRANSFORMS
668 num_layers = int(val)
670 weight_idx = int(val)
672 assert False,
'Unknown RandAugment config section'
673 ra_ops =
rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
675 return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
678_AUGMIX_TRANSFORMS = [
681 'ContrastIncreasing',
682 'BrightnessIncreasing',
683 'SharpnessIncreasing',
686 'PosterizeIncreasing',
687 'SolarizeIncreasing',
696 hparams = hparams
or _HPARAMS_DEFAULT
697 transforms = transforms
or _AUGMIX_TRANSFORMS
699 name, prob=1.0, magnitude=magnitude, hparams=hparams)
for name
in transforms]
704 Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
705 From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
706 https://arxiv.org/abs/1912.02781
709 def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
724 return np.array(rws[::-1], dtype=np.float32)
731 img_orig = img.copy()
734 depth = self.
depth if self.
depth > 0
else np.random.randint(1, 4)
735 ops = np.random.choice(self.
ops, depth, replace=
True)
738 img_aug = op(img_aug)
739 img = Image.blend(img, img_aug, w)
746 img_shape = img.size[0], img.size[1], len(img.getbands())
747 mixed = np.zeros(img_shape, dtype=np.float32)
748 for mw
in mixing_weights:
749 depth = self.
depth if self.
depth > 0
else np.random.randint(1, 4)
750 ops = np.random.choice(self.
ops, depth, replace=
True)
753 img_aug = op(img_aug)
754 mixed += mw * np.asarray(img_aug, dtype=np.float32)
755 np.clip(mixed, 0, 255., out=mixed)
756 mixed = Image.fromarray(mixed.astype(np.uint8))
757 return Image.blend(img, mixed, m)
760 mixing_weights = np.float32(np.random.dirichlet([self.
alpha] * self.
width))
761 m = np.float32(np.random.beta(self.
alpha, self.
alpha))
770 """ Create AugMix PyTorch transform
771 :param config_str: String defining configuration of the AugMix augmentation. Consists of multiple sections separated by
772 dashes ('-'). The first section defines the specific variant of AugMix (currently only 'augmix'). The remaining
773 sections, not order specific, determine
774 'm' - integer magnitude (severity) of augmentation mix (default: 3)
775 'w' - integer width of augmentation chain (default: 3)
776 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
777 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
778 'mstd' - float std deviation of magnitude noise applied (default: 0)
779 Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
780 :param hparams: Other hparams (kwargs) for the Augmentation transforms
781 :return: A PyTorch compatible Transform
788 config = config_str.split(
'-')
789 assert config[0] ==
'augmix'
792 cs = re.split(
r'(\d.*)', c)
798 hparams.setdefault(
'magnitude_std', float(val))
810 assert False,
'Unknown AugMix config section'
811 ops =
augmix_ops(magnitude=magnitude, hparams=hparams)
812 return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)