Safemotion Lib
Loading...
Searching...
No Matches
autoaugment.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7""" AutoAugment, RandAugment, and AugMix for PyTorch
8This code implements the searched ImageNet policies with various tweaks and improvements and
9does not include any of the search code.
10AA and RA Implementation adapted from:
11 https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12AugMix adapted from:
13 https://github.com/google-research/augmix
14Papers:
15 AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
16 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
17 RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
18 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
19Hacked together by Ross Wightman
20"""
21import math
22import random
23import re
24
25import PIL
26import numpy as np
27from PIL import Image, ImageOps, ImageEnhance
28
# (major, minor) version of the installed Pillow; used to gate version-specific arguments.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default fill colour handed to the transforms as ``fillcolor`` / ``img_mean``.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.

# Hyper-parameters used whenever a caller does not supply its own.
_HPARAMS_DEFAULT = {
    'translate_const': 57,
    'img_mean': _FILL,
}

# Pool of resampling filters; one is drawn at random per transform call.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
43
44
def _interpolation(kwargs):
    """Pop ``'resample'`` from *kwargs* and resolve it to a single filter.

    A list or tuple is treated as a pool and one entry is drawn at random;
    any other value is returned unchanged. ``Image.BILINEAR`` is used when
    the key is absent.
    """
    resample = kwargs.pop('resample', Image.BILINEAR)
    is_pool = isinstance(resample, (list, tuple))
    return random.choice(resample) if is_pool else resample
51
52
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    ``fillcolor`` is removed when the installed Pillow is older than 5.0, and
    ``resample`` is replaced by the single filter chosen by `_interpolation`.
    """
    fill_unsupported = _PIL_VER < (5, 0)
    if fill_unsupported and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
57
58
def shear_x(img, factor, **kwargs):
    """Shear *img* along the x axis by *factor* via an affine transform."""
    _check_args_tf(kwargs)
    affine = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
62
63
def shear_y(img, factor, **kwargs):
    """Shear *img* along the y axis by *factor* via an affine transform."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
67
68
def translate_x_rel(img, pct, **kwargs):
    """Translate *img* horizontally by the fraction *pct* of its width."""
    width = img.size[0]
    pixels = pct * width
    _check_args_tf(kwargs)
    affine = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
73
74
def translate_y_rel(img, pct, **kwargs):
    """Translate *img* vertically by the fraction *pct* of its height."""
    height = img.size[1]
    pixels = pct * height
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
79
80
def translate_x_abs(img, pixels, **kwargs):
    """Translate *img* horizontally by an absolute number of *pixels*."""
    _check_args_tf(kwargs)
    affine = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
84
85
def translate_y_abs(img, pixels, **kwargs):
    """Translate *img* vertically by an absolute number of *pixels*."""
    _check_args_tf(kwargs)
    affine = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
89
90
def rotate(img, degrees, **kwargs):
    """Rotate *img* by *degrees*, adapting to the installed Pillow version.

    * Pillow >= 5.2: ``img.rotate`` receives every keyword argument.
    * Pillow 5.0 / 5.1: the rotation about the image centre is written as an
      affine matrix and applied through ``img.transform``.
    * Older Pillow: ``img.rotate`` receives only the resampling filter.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)

    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    width, height = img.size
    centre_x, centre_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    cos_t = round(math.cos(theta), 15)
    sin_t = round(math.sin(theta), 15)
    neg_sin_t = round(-math.sin(theta), 15)

    # Translation terms: map the centre to the origin, rotate, then move it back.
    shift_x = cos_t * -centre_x + sin_t * -centre_y + centre_x
    shift_y = neg_sin_t * -centre_x + cos_t * -centre_y + centre_y

    affine = [cos_t, sin_t, shift_x, neg_sin_t, cos_t, shift_y]
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
121
122
def auto_contrast(img, **__):
    """Return *img* processed by ``ImageOps.autocontrast``; extra keywords are ignored."""
    result = ImageOps.autocontrast(img)
    return result
125
126
def invert(img, **__):
    """Return *img* processed by ``ImageOps.invert``; extra keywords are ignored."""
    result = ImageOps.invert(img)
    return result
129
130
def equalize(img, **__):
    """Return *img* processed by ``ImageOps.equalize``; extra keywords are ignored."""
    result = ImageOps.equalize(img)
    return result
133
134
def solarize(img, thresh, **__):
    """Solarize *img* with ``ImageOps.solarize`` using threshold *thresh*."""
    return ImageOps.solarize(img, threshold=thresh)
137
138
def solarize_add(img, add, thresh=128, **__):
    """Add *add* to every pixel value below *thresh*, capping the result at 255.

    Only modes ``"L"`` and ``"RGB"`` are processed (through a lookup table
    passed to ``img.point``); any other mode is returned untouched.
    """
    if img.mode not in ("L", "RGB"):
        return img

    # One 256-entry table: shifted below the threshold, identity at or above it.
    table = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode == "RGB":
        # The table is repeated once per band.
        table = table * 3
    return img.point(table)
152
153
def posterize(img, bits_to_keep, **__):
    """Posterize *img* to *bits_to_keep* bits per channel.

    Eight or more bits means nothing would be discarded, so the image is
    returned as-is without calling ``ImageOps.posterize``.
    """
    if bits_to_keep < 8:
        return ImageOps.posterize(img, bits=bits_to_keep)
    return img
158
159
def contrast(img, factor, **__):
    """Adjust the contrast of *img* by *factor* with ``ImageEnhance.Contrast``."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
162
163
def color(img, factor, **__):
    """Adjust the colour balance of *img* by *factor* with ``ImageEnhance.Color``."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
166
167
def brightness(img, factor, **__):
    """Adjust the brightness of *img* by *factor* with ``ImageEnhance.Brightness``."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
170
171
def sharpness(img, factor, **__):
    """Adjust the sharpness of *img* by *factor* with ``ImageEnhance.Sharpness``."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
174
175
def _randomly_negate(v):
    """With 50% prob, negate the value"""
    if random.random() > 0.5:
        return -v
    return v
179
180
def _rotate_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding a rotation angle in [-30, 30] with random sign."""
    degrees = _randomly_negate((level / _MAX_LEVEL) * 30.)
    return (degrees,)
186
187
def _enhance_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding an enhancement factor in [0.1, 1.9]."""
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
191
192
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding an enhancement factor in [0.1, 1.9].

    The 'no change' factor is 1.0; a higher level moves the factor further
    from 1.0, towards 0.1 or 1.9 with a randomly chosen direction.
    """
    offset = _randomly_negate((level / _MAX_LEVEL) * .9)
    return (1.0 + offset,)
199
200
def _shear_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding a shear factor in [-0.3, 0.3] with random sign."""
    shear = _randomly_negate((level / _MAX_LEVEL) * 0.3)
    return (shear,)
206
207
def _translate_abs_level_to_arg(level, hparams):
    """Map *level* to a 1-tuple holding a pixel offset with random sign.

    The maximum offset is ``hparams['translate_const']``.
    """
    max_pixels = float(hparams['translate_const'])
    offset = _randomly_negate((level / _MAX_LEVEL) * max_pixels)
    return (offset,)
213
214
def _translate_rel_level_to_arg(level, hparams):
    """Map *level* to a 1-tuple holding a relative offset with random sign.

    The maximum fraction is ``hparams['translate_pct']`` (0.45 when unset),
    giving a default range of [-0.45, 0.45].
    """
    max_fraction = hparams.get('translate_pct', 0.45)
    fraction = _randomly_negate((level / _MAX_LEVEL) * max_fraction)
    return (fraction,)
221
222
def _posterize_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding the number of bits to keep, range [0, 4].

    As per the Tensorflow TPU EfficientNet impl: 'keep 0 up to 4 MSB of
    original image'. Augmentation severity decreases as the level grows.
    """
    bits = int((level / _MAX_LEVEL) * 4)
    return (bits,)
228
229
def _posterize_increasing_level_to_arg(level, hparams):
    """Map *level* to a 1-tuple holding the number of bits to keep, range [4, 0].

    As per the Tensorflow models research and UDA impl: 'keep 4 down to 0 MSB
    of original image'. Augmentation severity increases with the level.
    """
    decreasing_bits = _posterize_level_to_arg(level, hparams)[0]
    return (4 - decreasing_bits,)
235
236
def _posterize_original_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding the number of bits to keep, range [4, 8].

    As per the original AutoAugment paper description: 'keep 4 up to 8 MSB of
    image'. Augmentation severity decreases as the level grows.
    """
    bits = int((level / _MAX_LEVEL) * 4) + 4
    return (bits,)
242
243
def _solarize_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding a solarize threshold in [0, 256].

    Augmentation severity decreases as the level grows.
    """
    threshold = int((level / _MAX_LEVEL) * 256)
    return (threshold,)
248
249
def _solarize_increasing_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding a solarize threshold in [0, 256].

    Augmentation severity increases with the level.
    """
    decreasing_threshold = _solarize_level_to_arg(level, _hparams)[0]
    return (256 - decreasing_threshold,)
254
255
def _solarize_add_level_to_arg(level, _hparams):
    """Map *level* to a 1-tuple holding the solarize-add amount, range [0, 110]."""
    amount = int((level / _MAX_LEVEL) * 110)
    return (amount,)
259
260
# Op name -> function turning a magnitude level into the op's positional
# argument tuple. ``None`` marks ops that take no magnitude argument.
LEVEL_TO_ARG = {
    # No magnitude argument.
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,

    'Rotate': _rotate_level_to_arg,

    # There are several variations of the posterize level scaling in various
    # Tensorflow/Google repositories/papers
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,

    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,

    # ImageEnhance based ops.
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,

    # Geometric ops.
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
}

# Op name -> image function. The '...Increasing' / '...Original' variants
# share the image function of the base op and differ only in LEVEL_TO_ARG.
NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,

    'Rotate': rotate,

    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,

    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,

    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,

    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
}
315
316
318
class AugmentOp:
    """One named augmentation, applied with a probability and a magnitude.

    The image function and the level-to-argument mapper are looked up by
    *name* in ``NAME_TO_OP`` and ``LEVEL_TO_ARG``.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        """
        :param name: Key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG``
        :param prob: Probability of applying the op on a given call
        :param magnitude: Level passed to the level-to-argument mapper
        :param hparams: Optional dict of hyper-parameters; ``_HPARAMS_DEFAULT`` when falsy
        """
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        # Private copy so the caller's dict is not shared with this op.
        self.hparams = hparams.copy()
        # Keyword arguments forwarded to every image function call.
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        """Return *img* augmented, or unchanged when the probability draw skips the op."""
        skip = self.prob < 1.0 and random.random() > self.prob
        if skip:
            return img

        level = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            level = random.gauss(level, self.magnitude_std)
        level = min(_MAX_LEVEL, max(0, level))  # clip to valid range

        if self.level_fn is None:
            op_args = tuple()
        else:
            op_args = self.level_fn(level, self.hparams)
        return self.aug_fn(img, *op_args, **self.kwargs)
346
347
def auto_augment_policy_v0(hparams):
    """Build the ImageNet v0 policy as a list of sub-policies of `AugmentOp`.

    ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper
    reference. Each entry is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in sub_policies
    ]
379
380
def auto_augment_policy_v0r(hparams):
    """Build the ImageNet v0 policy with the 'increasing' posterize variant.

    ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize
    used in Google research implementation (number of bits discarded
    increases with magnitude). Each entry is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in sub_policies
    ]
413
414
def auto_augment_policy_original(hparams):
    """Build the ImageNet policy from https://arxiv.org/abs/1805.09501.

    Each entry is ``(op name, probability, magnitude)``; the result is a list
    of sub-policies, each a list of `AugmentOp`.
    """
    sub_policies = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in sub_policies
    ]
446
447
def auto_augment_policy_originalr(hparams):
    """Build the paper's ImageNet policy with the research posterize variation.

    ImageNet policy from https://arxiv.org/abs/1805.09501 with
    'PosterizeIncreasing' in place of 'PosterizeOriginal'. Each entry is
    ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(*op_spec, hparams=hparams) for op_spec in sub_policy]
        for sub_policy in sub_policies
    ]
479
480
def auto_augment_policy(name="original"):
    """Return the AutoAugment policy called *name*, built with the default hparams.

    :param name: One of 'original', 'originalr', 'v0', 'v0r'
    :return: List of sub-policies, each a list of `AugmentOp`
    """
    hparams = _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    if name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    assert False, 'Unknown AA policy (%s)' % name
493
494
496
class AutoAugment:
    """Apply a randomly chosen AutoAugment sub-policy with a decaying probability.

    ``gamma`` starts at 0 and grows by ``1 / total_iter`` each time a
    sub-policy is applied (capped at 1.0). An image is augmented only when a
    uniform draw from [0, 1] exceeds ``gamma``, so augmentation becomes less
    likely the more often it has already been applied.
    """

    def __init__(self, total_iter, policy=None):
        """
        :param total_iter: Number of applied augmentations after which ``gamma`` reaches 1.0
        :param policy: Optional list of sub-policies (each a list of callables);
            the default policy of `auto_augment_policy` is used when omitted
        """
        self.total_iter = total_iter
        self.gamma = 0
        # __call__ samples from self.policy, so it must always be set here.
        self.policy = auto_augment_policy() if policy is None else policy

    def __call__(self, img):
        """Return *img* after one random sub-policy, or unchanged when the draw skips it."""
        if random.uniform(0, 1) > self.gamma:
            sub_policy = random.choice(self.policy)
            self.gamma = min(1.0, self.gamma + 1.0 / self.total_iter)
            for op in sub_policy:
                img = op(img)
        return img
511
512
def auto_augment_transform(config_str, hparams):
    """
    Create an AutoAugment transform
    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
    The remaining sections, not order specific, determine
        'mstd' -  float std deviation of magnitude noise applied
    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
    :return: A PyTorch compatible Transform
    """
    policy_name, *sections = config_str.split('-')
    for section in sections:
        # Split 'mstd0.5' into its key ('mstd') and numeric value ('0.5').
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[0], parts[1]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            assert False, 'Unknown AutoAugment config section'
    # NOTE(review): auto_augment_policy() is called without hparams, so the
    # 'magnitude_std' parsed above does not appear to reach the policy ops.
    policy = auto_augment_policy(policy_name)
    # NOTE(review): AutoAugment.__init__ names its first parameter
    # 'total_iter', yet a policy is passed here - confirm the intended call.
    return AutoAugment(policy)
539
540
# Op names used by RandAugment by default.
_RAND_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert',
    'Rotate',
    'Posterize', 'Solarize', 'SolarizeAdd',
    'Color', 'Contrast', 'Brightness', 'Sharpness',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
    # 'Cutout'  # NOTE I've implement this as random erasing separately
]

# Same ops, using the variants whose severity increases with magnitude.
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert',
    'Rotate',
    'PosterizeIncreasing', 'SolarizeIncreasing', 'SolarizeAdd',
    'ColorIncreasing', 'ContrastIncreasing', 'BrightnessIncreasing', 'SharpnessIncreasing',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
    # 'Cutout'  # NOTE I've implement this as random erasing separately
]

# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
    # Geometric ops carry most of the weight.
    'Rotate': 0.3,
    'ShearX': 0.2,
    'ShearY': 0.2,
    'TranslateXRel': 0.1,
    'TranslateYRel': 0.1,
    # Lightly weighted ops.
    'Color': .025,
    'Sharpness': 0.025,
    'AutoContrast': 0.025,
    'Solarize': .005,
    'SolarizeAdd': .005,
    'Contrast': .005,
    'Brightness': .005,
    'Equalize': .005,
    # Zero weight: never chosen.
    'Posterize': 0,
    'Invert': 0,
}
598
599
def _select_rand_weights(weight_idx=0, transforms=None):
    """Return normalised choice probabilities for *transforms*.

    :param weight_idx: Index of the weight set; only 0 exists
    :param transforms: Op names to look up; ``_RAND_TRANSFORMS`` when falsy
    :return: Numpy array of weights divided by their sum, in *transforms* order
    """
    names = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    weight_table = _RAND_CHOICE_WEIGHTS_0
    raw_weights = [weight_table[name] for name in names]
    return np.asarray(raw_weights) / np.sum(raw_weights)
607
608
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    """Build one `AugmentOp` (probability 0.5) per RandAugment op name.

    :param magnitude: Magnitude given to every op
    :param hparams: Hyper-parameter dict; ``_HPARAMS_DEFAULT`` when falsy
    :param transforms: Op names; ``_RAND_TRANSFORMS`` when falsy
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=0.5, magnitude=magnitude, hparams=hparams))
    return ops
614
615
class RandAugment:
    """Apply ``num_layers`` ops drawn at random from ``ops`` to each image."""

    def __init__(self, ops, num_layers=2, choice_weights=None):
        """
        :param ops: Sequence of callables taking and returning an image
        :param num_layers: Number of ops drawn per image
        :param choice_weights: Optional per-op probabilities for the draw
        """
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Return *img* after the randomly drawn ops have been applied in order."""
        # no replacement when using weighted choice
        with_replacement = self.choice_weights is None
        chosen_ops = np.random.choice(
            self.ops,
            self.num_layers,
            replace=with_replacement,
            p=self.choice_weights,
        )
        for chosen in chosen_ops:
            img = chosen(img)
        return img
629
630
def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme; 'magnitude_std' is added to this dict
    when an 'mstd' section is present and the key is not already set
    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')
    assert config[0] == 'rand'
    for section in config[1:]:
        # Split e.g. 'mstd0.5' into its key ('mstd') and numeric value ('0.5').
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            # val is a string: bool('0') is True, so parse the integer first.
            if bool(int(val)):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
676
677
# Op names used by AugMix by default.
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    # The four enhancement ops below are not in the paper.
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'Equalize',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'ShearX', 'ShearY',
    'TranslateXRel', 'TranslateYRel',
]
693
694
def augmix_ops(magnitude=10, hparams=None, transforms=None):
    """Build one always-applied `AugmentOp` (probability 1.0) per AugMix op name.

    :param magnitude: Magnitude given to every op
    :param hparams: Hyper-parameter dict; ``_HPARAMS_DEFAULT`` when falsy
    :param transforms: Op names; ``_AUGMIX_TRANSFORMS`` when falsy
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _AUGMIX_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=1.0, magnitude=magnitude, hparams=hparams))
    return ops
700
701
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781
    """

    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        """
        :param ops: Sequence of callables taking and returning an image
        :param alpha: Parameter of the Dirichlet (chain weights) and Beta (final blend) draws
        :param width: Number of augmentation chains mixed together
        :param depth: Ops per chain; a value <= 0 draws a random depth in [1, 3] per chain
        :param blended: Use the per-chain PIL blend path instead of the Numpy accumulation
        """
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _augment_chain(self, img):
        """Return *img* after one randomly drawn chain of ops."""
        depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
        ops = np.random.choice(self.ops, depth, replace=True)
        img_aug = img  # no ops are in-place, deep copy not necessary
        for op in ops:
            img_aug = op(img_aug)
        return img_aug

    def _calc_blended_weights(self, ws, m):
        """Convert chain weights *ws* (scaled by *m*) into sequential blend factors.

        The factors are derived back to front and returned in the original
        order as a float32 array, one per chain.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        """Blend each augmented chain into the running result with one PIL blend per chain."""
        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            img_aug = self._augment_chain(img_orig)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        """Accumulate the weighted chains in a float array, then blend with *img* by *m*."""
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        # PIL's ``size`` is (width, height) while the array form is
        # (height, width[, bands]); take the shape from the array so that
        # non-square images accumulate correctly.
        mixed = np.zeros(np.asarray(img).shape, dtype=np.float32)
        for mw in mixing_weights:
            img_aug = self._augment_chain(img)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """Return the AugMix result for *img* using freshly drawn mixing weights."""
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed
767
768
def augment_and_mix_transform(config_str, hparams):
    """ Create AugMix PyTorch transform
    :param config_str: String defining configuration of the augmentation mix. Consists of multiple sections separated by
    dashes ('-'). The first section must be 'augmix'. The remaining sections, not order specific, determine
        'm' - integer magnitude (severity) of augmentation mix (default: 3)
        'w' - integer width of augmentation chain (default: 3)
        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
        'a' - float alpha of the mixing distributions (default: 1.0)
        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
        'mstd' -  float std deviation of magnitude noise applied (default: 0)
    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
    :param hparams: Other hparams (kwargs) for the Augmentation transforms; 'magnitude_std' is added to this dict
    when an 'mstd' section is present and the key is not already set
    :return: A PyTorch compatible Transform
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    assert config[0] == 'augmix'
    for section in config[1:]:
        # Split e.g. 'mstd0.5' into its key ('mstd') and numeric value ('0.5').
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # val is a string: bool('0') is True, so parse the integer first.
            blended = bool(int(val))
        else:
            assert False, 'Unknown AugMix config section'
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
__init__(self, ops, alpha=1., width=3, depth=-1, blended=False)
__init__(self, name, prob=0.5, magnitude=10, hparams=None)
__init__(self, ops, num_layers=2, choice_weights=None)
_posterize_increasing_level_to_arg(level, hparams)
augment_and_mix_transform(config_str, hparams)
_solarize_add_level_to_arg(level, _hparams)
_posterize_original_level_to_arg(level, _hparams)
auto_augment_transform(config_str, hparams)
translate_y_abs(img, pixels, **kwargs)
translate_x_abs(img, pixels, **kwargs)
rand_augment_transform(config_str, hparams)
rand_augment_ops(magnitude=10, hparams=None, transforms=None)
rotate(img, degrees, **kwargs)
translate_y_rel(img, pct, **kwargs)
shear_y(img, factor, **kwargs)
shear_x(img, factor, **kwargs)
posterize(img, bits_to_keep, **__)
solarize_add(img, add, thresh=128, **__)
_select_rand_weights(weight_idx=0, transforms=None)
augmix_ops(magnitude=10, hparams=None, transforms=None)
_solarize_increasing_level_to_arg(level, _hparams)
_enhance_increasing_level_to_arg(level, _hparams)
translate_x_rel(img, pct, **kwargs)